# Einsum Puzzles
This Colab will teach you how to use numpy.einsum through examples.

In [1]:
# einsum is available in numpy, torch and JAX, so the following would all work:
# np.einsum, torch.einsum, jnp.einsum
# Let's start with numpy.

import numpy as np

def check_answer(expected, given):
  """Compare a puzzle answer against the expected value and print a verdict.

  Both values are converted with np.asarray first, so plain Python scalars and
  nested lists are accepted in addition to numpy arrays.

  Args:
    expected: The reference result.
    given: The result to validate.

  Returns:
    True if the shapes match and all elements are close (np.allclose),
    False otherwise (including on a shape mismatch).
  """
  expected = np.asarray(expected)
  given = np.asarray(given)
  if expected.shape != given.shape:
    print("Shape mismatch!")
    print(f"Expected shape: {expected.shape}")
    print(f"Given shape: {given.shape}")
    # Return an explicit False so every code path yields a bool.
    return False
  if np.allclose(expected, given):
    print("Correct!")
    return True
  print("Incorrect!")
  return False

In [3]:
# This visualization code is experimental and only works for tensors of rank up to 2.

from IPython.display import SVG, display

def _tensor_box_size(num_indices, shape, wide, narrow):
    """Return the (width, height) of the box drawn for one tensor.

    Only rank-1 tensors are narrowed: with shape None or "v" the width is
    narrow (drawn as a column), with shape "h" the height is narrow (drawn as
    a row). Everything else is drawn as a wide square.
    """
    width = height = wide
    if num_indices == 1:
        if shape is None or shape == "v":
            width = narrow
        elif shape == "h":
            height = narrow
    return width, height


def _svg_tensor_box(x, y, width, height, fill, name):
    """Return the SVG markup for a tensor rectangle with its centered name."""
    return (
        f'    <rect x="{x}" y="{y}" width="{width}" height="{height}" '
        f'fill="{fill}" stroke="black" stroke-width="1"/>\n'
        f'    <text x="{x + width/2}" y="{y + height/2}" '
        'text-anchor="middle" dominant-baseline="middle" '
        f'font-size="14">{name}</text>\n'
    )


def _svg_index_arrows(x, y, width, height, indices, first_vertical):
    """Return the SVG markup for the dimension arrows and their index labels.

    The first index is drawn as a vertical arrow to the left of the box when
    first_vertical is true; every other index is drawn as a horizontal arrow
    below the box.
    """
    parts = []
    for position, idx in enumerate(indices):
        if position == 0 and first_vertical:
            parts.append(
                f'    <line x1="{x-20}" y1="{y}" x2="{x-20}" '
                f'y2="{y+height}" stroke="black" '
                'marker-end="url(#arrowhead)"/>\n'
                f'    <text x="{x-10}" y="{y + height/2}" '
                f'text-anchor="middle" font-size="12">{idx}</text>\n'
            )
        else:
            parts.append(
                f'    <line x1="{x}" y1="{y+height+20}" '
                f'x2="{x+width}" y2="{y+height+20}" '
                'stroke="black" marker-end="url(#arrowhead)"/>\n'
                f'    <text x="{x + width/2}" y="{y+height+35}" '
                f'text-anchor="middle" font-size="12">{idx}</text>\n'
            )
    return "".join(parts)


def visualize_einsum(
    einsum_str,
    tensor_names=None,
    tensor_shapes=None,
    result_shape=None,
    result_rank=None,
    width_px=300):
    """
    Generate a SVG visualization of tensors and their dimensions, with configurable width.

    Each input is drawn as a blue box and the result as a red box named "r";
    arrows next to a box are labeled with the index letters of that tensor.
    Only tensors of rank up to 2 are laid out sensibly.

    Args:
        einsum_str: The einsum subscripts string, e.g. "ij,jk->ik". Whitespace
            is ignored.
        tensor_names: Optional names for the inputs. Missing names are filled
            in with "A", "B", ... by position. The list is not modified.
        tensor_shapes: Optional per-input layout hints for rank-1 inputs:
            "v" (or None) draws a column, "h" draws a row. Missing entries
            are treated as None.
        result_shape: Same kind of hint ("v", "h" or None) for the result.
        result_rank: Rank of the result if known. When it is None, a result
            with no output indices after an explicit "->" is drawn as a scalar.
        width_px: Rendered width of the SVG in pixels.

    Returns:
        An IPython.display.SVG object.
    """
    # Drop whitespace so that spaces in the subscripts string are not
    # mistaken for index letters.
    einsum_str = "".join(einsum_str.split())
    lhs, arrow, output = einsum_str.partition("->")
    inputs = lhs.split(",")

    # Work on a copy so the caller's list is never mutated.
    names = list(tensor_names) if tensor_names is not None else []
    names.extend(chr(65 + i) for i in range(len(names), len(inputs)))

    # SVG parameters with explicit aspect ratio control
    box_wide = 60
    box_narrow = box_wide // 3
    arrow_margin = 30
    label_margin = 20
    spacing = 100

    # Total canvas size: one slot per input plus one for the result.
    total_width = spacing * (len(inputs) + 1) + arrow_margin * 2
    total_height = box_wide + arrow_margin * 2 + label_margin * 2

    # Scale the height so the aspect ratio is kept at the requested width.
    scale_factor = width_px / total_width
    scaled_height = int(total_height * scale_factor)

    svg = f'''<svg viewBox="0 0 {total_width} {total_height}"
              xmlns="http://www.w3.org/2000/svg"
              width="{width_px}"
              height="{scaled_height}"
              style="margin: 10px;">\n'''

    svg += '''    <defs>
        <marker id="arrowhead" markerWidth="6" markerHeight="4"
                refX="6" refY="2" orient="auto">
            <polygon points="0 0, 6 2, 0 4" fill="black"/>
        </marker>
    </defs>\n'''

    # Starting positions
    start_x = arrow_margin
    start_y = arrow_margin

    # Draw input tensors
    for i, (indices, name) in enumerate(zip(inputs, names)):
        shape = None
        if tensor_shapes is not None and i < len(tensor_shapes):
            shape = tensor_shapes[i]

        is_vector = len(indices) == 1
        tensor_width, tensor_height = _tensor_box_size(
            len(indices), shape, box_wide, box_narrow)

        x = start_x + i * spacing
        svg += _svg_tensor_box(x, start_y, tensor_width, tensor_height,
                               "#e6f3ff", name)
        # A row vector's single index runs horizontally, like the result's.
        svg += _svg_index_arrows(
            x, start_y, tensor_width, tensor_height, indices,
            first_vertical=not (is_vector and shape == "h"))

    # Draw output tensor
    if result_rank is not None:
        is_scalar = result_rank == 0
    else:
        # An explicit "->" with nothing after it always yields a scalar.
        is_scalar = bool(arrow) and not output

    if output or is_scalar:
        is_vector = len(output) == 1
        tensor_width, tensor_height = _tensor_box_size(
            len(output), result_shape, box_wide, box_narrow)

        if is_scalar:
            # Special case of a scalar output
            tensor_width = tensor_height = box_narrow

        x = start_x + len(inputs) * spacing
        svg += _svg_tensor_box(x, start_y, tensor_width, tensor_height,
                               "#ffe6e6", "r")

        if not is_scalar:
            svg += _svg_index_arrows(
                x, start_y, tensor_width, tensor_height, output,
                first_vertical=not (is_vector and result_shape == "h"))

    svg += '</svg>'
    return SVG(svg)

In [4]:
# A version of einsum that also validates correctness of the result.

import numpy as np
import inspect

def _find_variable_name(value, namespace):
    """Return a name bound to `value` (compared by identity) in `namespace`.

    Names without a leading underscore are preferred, so that IPython's
    output-history variables (`_`, `__`, ...) do not shadow the name the user
    actually chose. Returns None if no binding is found.
    """
    fallback = None
    for var_name, var_value in list(namespace.items()):
        if var_value is value:
            if not var_name.startswith("_"):
                return var_name
            if fallback is None:
                fallback = var_name
    return fallback


def einsum(einsum_expr,
           *tensors,
           expected=None,
           result_shape=None):
    """Run np.einsum, optionally validate the result, and visualize the operation.

    Args:
        einsum_expr: The einsum subscripts string.
        *tensors: The input tensors, passed through to np.einsum.
        expected: Optional reference result. When given, the result is checked
            with check_answer and the visualization is skipped if the check
            does not pass.
        result_shape: Optional layout hint forwarded to visualize_einsum.

    Returns:
        The result of np.einsum(einsum_expr, *tensors).
    """
    # Look up the names the caller uses for the tensors so the visualization
    # can label them. The frame references are dropped in `finally` to avoid
    # keeping a reference cycle alive.
    frame = inspect.currentframe()
    caller = None
    try:
        caller = frame.f_back if frame is not None else None
        context_vars = caller.f_locals if caller is not None else {}

        tensor_names = []
        for tensor in tensors:
            var_name = _find_variable_name(tensor, context_vars)
            if var_name is None:
                # If we couldn't find the variable name, use a placeholder
                var_name = f"tensor_{len(tensor_names)}"
            tensor_names.append(var_name)
    finally:
        del frame, caller

    # Perform the einsum operation
    result = np.einsum(einsum_expr, *tensors)

    if expected is not None and not check_answer(expected, result):
      # Don't attempt visualization if the result is incorrect
      return result

    # Visualize the operation. The rank is taken from the actual result, so it
    # is known even when no expected value was supplied.
    svg = visualize_einsum(einsum_expr,
                           tensor_names=tensor_names,
                           result_shape=result_shape,
                           result_rank=np.ndim(result))
    if svg:
      display(svg)

    return result

`einsum` allows you to create operations in which multiple inputs (vectors, arrays, or more generally, tensors of any rank) are combined to create a single result. A typical call would look like this:
```
d = np.einsum("◻◻,◻◻◻,◻◻->◻◻", a, b, c)
```
where `a`, `b`, and `c` are inputs and `d` is the result. The subscripts string `"◻◻,◻◻◻,◻◻->◻◻"` has two parts separated by `->`:
- The left part describes the dimensions of the inputs. There will be one description per input and they are separated by commas.
- The right part describes the dimensions of the result.

Here's our first example:

In [None]:
# Build a one-dimensional input vector of vector_length consecutive integers.
vector_length = 5
u = np.arange(vector_length)
print(f"{u.shape=}")
print(f"u: {u}")

# "i->i": the input has one dimension i, and the output keeps that dimension.
r = np.einsum("i->i", u)
print(f"{r.shape=}")
print(f"r: {r}")

In the example above, there is one input `u`, which is a vector. The left part of the subscripts string is `"i"`. The use of a single letter means that the input has one dimension. The output is also specified as `"i"`, so it ends up being identical to the input.

The choice of the letter doesn't matter, so instead of `"i"` we could have used `"x"` or any other letter, and the result would be the same:

In [None]:
# Same operation as "i->i"; only the index letter differs.
r = np.einsum("x->x", u)
print(f"r: {r}")

There is a shortcut for this operation. If we completely omit the `->` symbol and the right hand side, we can get as output the same vector as the input.

In [None]:
# Shortcut form: the "->" and the output subscripts are omitted.
r = np.einsum("i", u)
print(f"r: {r}")

Another important feature of `einsum` is that if the same letter appears in two different inputs on the left-hand side, then corresponding elements from the two inputs on the dimension described by that letter will be pairwise multiplied by each other.

In the example below, the result, `r`, is a vector of the same length as the two inputs `u` and `v`, and each of its elements is a product of two corresponding elements of `u` and `v`. That is, we want to compute `r` such that:

$$
r_i = u_i v_i
$$

In [None]:
print(f"u: {u}")
# v covers the integers from vector_length up to 2 * vector_length (exclusive).
v = np.arange(vector_length, 2 * vector_length)
print(f"v: {v}")
print(f"{v.shape=}")

# The index i is shared by both inputs and kept in the output.
r = np.einsum("i,i->i", u, v)
print(f"r: {r}")

# Verifying puzzle solutions

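In the rest of the tutorial, instead of using the numpy function `np.einsum`, we will use the function `einsum` defined earlier in this notebook. This function also validates the result against an expected answer, so in the puzzles we can check whether the provided answer is correct. For example, we can check that the sum of all elements of the vector `u` from the earlier examples is computed correctly. The subscripts string `"i->"` lists the index `i` on the left-hand side but not on the right-hand side, which tells `einsum` to add up all values along that dimension.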

In [None]:
# Compute the expected result using another method, in this case the np.sum function.
# The notebook's einsum wrapper compares its result against `expected`.
r = einsum("i->", u, expected=np.sum(u))

# Puzzle 1: inner product

Use your knowledge from the previous two examples to compute the inner product of `u` and `v`, that is, compute the sum of the pairwise products of the elements of `u` and `v`. Remember that:
- Repeating the same index twice on the left-hand side of `->` means that elements will be pairwise multiplied.
- Omitting an index on the right-hand side of `->` means that we will sum all values along that dimension.

The inner product is also implemented as `np.inner` and we will use that function to check the correctness of the answer.

The result should be:

$$
r = \sum_i u_i v_i
$$

In [None]:
print(f"u: {u}")
print(f"v: {v}")

# Fill in the first argument to the einsum call below (it is intentionally left empty):
r = einsum("", u, v, expected=np.inner(u, v))

The same principle applies to matrices. Consider this snippet, which doesn't change the input.

In [None]:
# Build a rows x cols matrix of consecutive integers.
rows, cols = 3, 4
a = np.arange(rows * cols).reshape(rows, cols)
print(f"{a.shape=}")
print(f"a:\n{a}")

# "ij->ij": both dimensions appear in the same order on both sides.
r = einsum("ij->ij", a)
print(f"{r.shape=}")
print(f"r:\n{r}")



# Puzzle 2: matrix transpose

Since the order of letters in the subscripts string determines the order of dimensions, you can take advantage of that to permute dimensions of a tensor. How would you use that to transpose a matrix? Fill in the subscripts string below to solve this puzzle.

In [None]:
# Fill in the first argument to the einsum call below (it is intentionally left empty):
r = einsum("", a, expected=a.T)

To summarize what we've learned so far:

- Letters that appear in the subscripts string both to the left and right of the `"->"` arrow are called *free indices*. They indicate no change in the dimensions denoted by those indices.

- If an index appears to the left of `"->"` but not to the right, it is called a *summation index* and `einsum` will add all values along that dimension. For example, we can use this to calculate the sum of all elements of our vector `u`.

- If an index appears twice on the left hand side, we will perform pairwise multiplication of the corresponding elements.

Consider another summation example. This time, we have a matrix with 3 rows and 4 columns and we want to add all elements in each row. The result is a vector of three elements, in which the first element is the sum of all 4 elements in the first row of the array, and so on:

$$
r_i = \sum_j a_{ij}
$$

We do this by not writing the column dimension (denoted by `j`) on the right-hand side:

In [None]:
# Recreate the rows x cols matrix so this cell does not depend on earlier ones.
rows, cols = 3, 4
a = np.arange(rows * cols).reshape(rows, cols)
print(f"{a.shape=}")
print(f"a:\n{a}")

# "ij->i": j is missing from the output, so values are summed along j.
r = einsum("ij->i", a)
print(f"{r.shape=}")
print(f"r:\n{r}")

# Puzzle 3: Summing columns of an array

How would you create a vector containing the sum of each column in the matrix? For our matrix with 3 rows and 4 columns, the result should be a vector of 4 elements, in which each element is the sum of the values in the corresponding column.

$$
r_j = \sum_i a_{ij}
$$

In [None]:
# Fill in the first argument to the einsum call below (it is intentionally left empty):
# The expected value sums over axis 0, i.e. one sum per column of a.
r = einsum("", a, expected=np.sum(a, axis=0))
print(f"{r.shape=}")
print(f"r:\n{r}")

# A note on vectors in this colab

Some of the examples operate on vectors and matrices. Depending on the type of operation, the vector will be either a row vector or a column vector. However, when Python prints a numpy vector, it is always printed in a single line of text, which makes it look like a row vector. If you want to figure out whether the vector is a row or column vector, you will need to use your knowledge of linear algebra.

The two most common operations we will encounter are:
- Multiplication of a vector by a matrix: The vector is a *row vector* with length equal to the number of rows of the matrix. The result is a *row vector*.
- Multiplication of a matrix by a vector: The vector is a *column vector* with length equal to the number of columns of the matrix. The result is a *column vector*.

In [None]:
# Create a matrix and two vectors with lengths matching the number of rows and
# the number of columns in the matrix, respectively.

rows, cols = 3, 4
a = np.arange(rows * cols).reshape(rows, cols)
print(f"{a.shape=}")
print(f"a:\n{a}")

# Note: this rebinds u, which earlier had length vector_length, to a vector
# of length cols so that it can be multiplied with a.
u = np.arange(cols)
print(f"{u.shape=}")
print(f"u: {u}")

w = np.arange(rows)
print(f"{w.shape=}")
print(f"w: {w}")

print("\nVector multiplied by matrix (the result is a row vector)")
wxa = w @ a
print(f"{wxa.shape=}")
print(f"wxa: {wxa}")


print("\nMatrix multiplied by vector (the result is a column vector)")
axu = a @ u
print(f"{axu.shape=}")
print(f"axu: {axu}")

print("\nUse matrix transpose to compute wxa as a column vector.")
wxa2 = a.T @ w
print(f"{wxa2.shape=}")
print(f"wxa2:\n{wxa2}")
print("In math, vectors wxa and wxa2 are row and column vectors, respectively.")
print("But in Python, they're printed the same way and have the same shape.")

print("\nThese two vectors are equal to each other.")
print(f"{np.array_equal(wxa, wxa2)=}")

# Puzzle 4: Multiply a matrix by a vector

How would you multiply a matrix by a vector? The vector has length equal to the number of columns in the matrix and the result is a vector with length equal to the number of rows of the matrix.

So, the $i$th element of the result can be expressed as:
$$
r_i = \sum_j a_{ij}u_j
$$

In [None]:
# Fill in the first argument to the einsum call below (it is intentionally left empty):
r = einsum("", a, u, expected=a @ u)
print(f"{r.shape=}")
print(f"r:\n{r}")

# Puzzle 5: Matrix multiplication

Given what we've learned so far about `einsum`, matrix multiplication can be expressed in an elegant way. We need to perform an inner product along two matching dimensions, which we already did in Puzzle 1. The other two dimensions are unchanged.

In [None]:
# b must have as many rows as a has columns for a @ b to be defined.
rows2 = cols
cols2 = 5
b = np.arange(rows2 * cols2).reshape(rows2, cols2)

# Fill in the first argument to the einsum call below (it is intentionally left empty).
r = einsum("", a, b, expected=a @ b)
print(f"{r.shape=}")
print(f"r:\n{r}")

# Puzzle 6: Part of the attention kernel

We will now look at a formula from the [Attention Is All You Need](https://arxiv.org/abs/1706.03762) paper. A part of the implementation of transformers computes this formula:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right)V
$$

For this puzzle, try to implement a small part of this formula:

$$
S = QK^T
$$

The nice thing about using `einsum` is that the whole operation can be performed in a single call. Without `einsum`, we need to first transpose $K$ and then multiply that with $Q$.

In [None]:
# Assume that Q and K are matrices (tensors of rank 2).

# Fix the seed so the random matrices are the same on every run.
np.random.seed(42)
Q = np.random.rand(4, 5)  # Query matrix (4 examples, 5 dimensions)
K = np.random.rand(4, 5)  # Key matrix (4 examples, 5 dimensions)

# Fill in the first argument to the einsum call below (it is intentionally
# left empty). It should be similar to what you did for matrix multiplication.
scores = einsum("", Q, K, expected=Q @ K.T)
print(f"{scores.shape=}")
print(f"scores:\n{scores}")