## Mental model to keep in mind

•  Each letter = one axis.
•  A letter that appears:
◦  In one or more operands but not in the output → that axis is summed over.
◦  In operands and in the output → that axis is kept (not summed).
◦  Only once in the whole expression (implicit mode) → that axis is kept and carried through to the output.
•  Implicit mode: 'ij,jk' → letters that appear more than once are summed; the output indices are the letters that appear exactly once, in alphabetical order.
•  Explicit mode: 'ij,jk->ik' → you control which letters are kept, and their order, explicitly.
•  ... (ellipsis) = “all the leftover axes here”.
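
The rules above can be checked on small arrays. The following is a minimal sketch; the names `P`, `Q`, and `T` are illustrative and are not part of the exercises.

```python
import numpy as np

P = np.arange(6).reshape(2, 3)
Q = np.arange(12).reshape(3, 4)

# 'j' is in both operands but not in the output, so it is summed over.
matmul_explicit = np.einsum('ij,jk->ik', P, Q)

# Implicit mode: 'j' is repeated (summed); 'i' and 'k' appear once and are
# kept in alphabetical order, so the result is again indexed as 'ik'.
matmul_implicit = np.einsum('ij,jk', P, Q)

# Explicit mode lets you choose the output order, here the transpose of P @ Q.
matmul_transposed = np.einsum('ij,jk->ki', P, Q)

# A letter left out of an explicit output is summed even if it appears once.
row_sums = np.einsum('ij->i', P)

# The ellipsis stands for every axis that is not named explicitly.
T = np.arange(24).reshape(2, 3, 4)
last_axis_sums = np.einsum('...k->...', T)

print(np.array_equal(matmul_explicit, P @ Q))
print(np.array_equal(matmul_implicit, P @ Q))
print(np.array_equal(matmul_transposed, (P @ Q).T))
print(np.array_equal(row_sums, P.sum(axis=1)))
print(np.array_equal(last_axis_sums, T.sum(axis=-1)))
```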

### Exercises

In [1]:
import numpy as np

rng = np.random.default_rng(0)  # generator referenced by the exercise text (rng.normal(...))

0.1. Create a 1D array a of length 5.  
Task:  
•  Write an einsum call that is exactly a no-op view: same data, no sum, no axis reordering.  
◦  Check: does np.einsum(?, a) is a evaluate to True or False? (The shape is the same either way.)

0.2. Same a.  
Task:  
•  Sum all elements of a using einsum.  
◦  Compare to np.sum(a).

0.3. Create a 2D array A of shape (3, 4).  
Tasks:  
•  Sum all elements with einsum in two different ways (one implicit, one explicit).  
◦  Compare to np.sum(A).

0.4. For A (3, 4):  
Tasks:  
•  Sum over rows (axis 0) with einsum. Compare to np.sum(A, axis=0).  
•  Sum over columns (axis 1) with einsum. Compare to np.sum(A, axis=1).

In [2]:
a = np.arange(5)
print(a)

[0 1 2 3 4]


In [3]:
# 0.1. No-op: one operand, no summation, no axis reordering
ans = np.einsum('i', a)
print(ans)

[0 1 2 3 4]
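
Exercise 0.1 also asks whether the result is the same object as `a`. A minimal check, assuming the usual NumPy behaviour in which a single-operand `einsum` with no summation returns a view rather than a copy:

```python
import numpy as np

a = np.arange(5)
ans = np.einsum('i', a)

# `is` compares object identity: einsum hands back a new ndarray object.
print(ans is a)
# The new object is expected to share the underlying buffer with `a`.
print(np.shares_memory(ans, a))
# Shape and values are unchanged.
print(ans.shape == a.shape, np.array_equal(ans, a))
```

So the call is a no-op on the data, but `is` is not the right test for it; `np.shares_memory` is.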


In [4]:
sum_all = np.einsum('i->', a)
print(sum_all)

10


In [5]:
b = np.arange(12).reshape(3,4)
print(b)

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]


In [6]:
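# 3 rows, 4 columns
# index 'i' for rows, index 'j' for columns
# 'ij->i' sums over j (axis=1, one total per row); 'ij->j' sums over i (axis=0, one total per column)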
einsum_sum_rows = np.einsum('ij->i', b)
ref_sum_rows = b.sum(axis=1)
print(einsum_sum_rows)
print(np.allclose(einsum_sum_rows, ref_sum_rows), einsum_sum_rows.shape, ref_sum_rows.shape)


einsum_sum_cols = np.einsum('ij->j', b)
ref_sum_cols = b.sum(axis=0)
print(einsum_sum_cols)
print(np.allclose(einsum_sum_cols, ref_sum_cols), einsum_sum_cols.shape, ref_sum_cols.shape)


einsum_total_sum = np.einsum('ij->', b)
ref_total_sum = b.sum()
print(einsum_total_sum)
print(np.allclose(einsum_total_sum, ref_total_sum), einsum_total_sum.shape, ref_total_sum.shape)

[ 6 22 38]
True (3,) (3,)
[12 15 18 21]
True (4,) (4,)
66
True () ()
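
Exercise 0.3 asks for an implicit-mode total as well. With a single operand, implicit mode keeps every letter that appears once, so 'ij' alone is a no-op. One possible reading of the exercise, sketched here as an assumption, is to pair the array with a same-shaped array of ones so that both letters are repeated and therefore summed:

```python
import numpy as np

b = np.arange(12).reshape(3, 4)
ones = np.ones_like(b)  # helper operand, introduced only for this sketch

# Both 'i' and 'j' appear twice, so implicit mode sums over both.
implicit_total = np.einsum('ij,ij', b, ones)
# Explicit mode: an empty output means every axis is summed.
explicit_total = np.einsum('ij->', b)

print(implicit_total, explicit_total, b.sum())
```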


## Level 1 – Vectors, inner/outer, and explicit vs implicit

Use vectors x, y of length 4: x = rng.normal(size=4); y = rng.normal(size=4).

### Exercises

1.1. Inner product of two vectors  
•  Use einsum to compute x·y.  
•  Compare to np.dot(x, y) and np.inner(x, y).

1.2. Outer product of two vectors  
•  Use einsum to compute the outer product of x and y.  
•  Compare to np.outer(x, y).

1.3. Squared L2 norm  
•  Use einsum to compute ||x||² = ∑ᵢ xᵢ² in two ways:
◦  Using one operand.
◦  Using x twice as operands.  
•  Compare to np.sum(x**2).

1.4. Implicit vs explicit  
Take A of shape (3, 4).

•  Write an implicit expression 'ij' (no ->) and inspect np.einsum('ij', A).shape.
•  Write an implicit expression 'ji' and inspect the shape.
•  Write an explicit expression 'ij->ji' and compare to A.T.  
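Explain to yourself why 'ij' and 'ij->ij' give the same result even though they follow different rules: implicit mode sums repeated labels and sorts the remaining output labels alphabetically, while explicit mode uses exactly the output you write. This is also why implicit 'ji' returns the transpose.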


In [7]:
x = np.arange(4).reshape(4)
y = np.arange(4).reshape(4)
print(x)
print(y)

[0 1 2 3]
[0 1 2 3]


In [8]:
ref_dot = np.dot(x, y)
print(ref_dot)

14


In [9]:
#  If you include the '->out_sub' part, that's explicit mode: you control the exact output labels and their order.
#  If you omit '->', einsum uses implicit mode: repeated labels are summed and the remaining axes are ordered alphabetically by label (this can reorder axes).
einsum_dot = np.einsum('i,i->', x, y)
print(einsum_dot)
# Keeping 'i' in the output gives the elementwise product instead of the dot product.
einsum_elementwise = np.einsum('i,i->i', x, y)
print(einsum_elementwise)
# Implicit mode: 'i' is repeated, so it is summed. Same result as 'i,i->'.
einsum_dot_implicit = np.einsum('i,i', x, y)
print(einsum_dot_implicit)

14
[0 1 4 9]
14


In [10]:
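# 1.2. Outer product of two vectors  
ref_outer = np.outer(x, y)
print(ref_outer)
# Keep both indices in the output; 'i,j->i' would sum over j instead.
einsum_outer = np.einsum('i,j->ij', x, y)
print(einsum_outer)

[[0 0 0 0]
 [0 1 2 3]
 [0 2 4 6]
 [0 3 6 9]]
[[0 0 0 0]
 [0 1 2 3]
 [0 2 4 6]
 [0 3 6 9]]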


In [11]:
# 1.3. Squared L2 norm (computed here on a matrix, i.e. the squared Frobenius norm)
z = np.arange(12).reshape(3,4)
l2_norm_ref = np.sum(z**2)
print(l2_norm_ref)
einsum_l2_sum = np.einsum('ij,ij->',z,z)
print(einsum_l2_sum)

506
506
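
Exercise 1.3 is stated for a vector and asks for two variants. A minimal sketch of both, assuming the "one operand" variant means squaring first and letting `einsum` do only the sum:

```python
import numpy as np

x = np.arange(4)

# One operand: square first, then let einsum sum the single axis.
norm_sq_one_operand = np.einsum('i->', x**2)
# Two operands: the repeated 'i' multiplies elementwise and sums.
norm_sq_two_operands = np.einsum('i,i->', x, x)

print(norm_sq_one_operand, norm_sq_two_operands, np.sum(x**2))
```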


In [12]:
# shapes
# same shape
print(np.einsum('ij',z).shape)
print(np.einsum('ij->ij',z).shape)
# transposed shape
print(np.einsum('ji',z).shape)
print(np.einsum('ij->ji',z).shape)


(3, 4)
(3, 4)
(4, 3)
(4, 3)
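
Shapes alone do not show why implicit 'ji' transposes. The sketch below compares values: in implicit mode the output labels are sorted alphabetically, so labelling the axes 'ji' and getting back 'ij' swaps them. The label pair 'ba' is included only to show that the letters themselves carry no meaning.

```python
import numpy as np

z = np.arange(12).reshape(3, 4)

# Implicit 'ji': axis 0 is labelled j, axis 1 is labelled i.
# The output is ordered alphabetically as 'ij', so the axes are swapped.
implicit_ji = np.einsum('ji', z)
# Any pair of labels in reverse alphabetical order does the same thing.
implicit_ba = np.einsum('ba', z)
# Explicit mode states the reordering directly.
explicit_ji = np.einsum('ij->ji', z)

print(np.array_equal(implicit_ji, z.T))
print(np.array_equal(implicit_ba, z.T))
print(np.array_equal(explicit_ji, z.T))
```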


In [13]:
m = np.arange(16).reshape(4,4)
print(m)
# hint: a repeated index ('ii') selects the diagonal
print(np.einsum('ii', m))
print(np.einsum('jj', m))
print(np.einsum('ii->i', m))
print(np.einsum('jj->j', m))

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
30
30
[ 0  5 10 15]
[ 0  5 10 15]


## Level 2 – Matrix–vector, matrix–matrix, and batched matmul

In [14]:
A = np.arange(12).reshape(3,4)
B = np.arange(20).reshape(4,5)
v = np.arange(4).reshape(4)
print(A)
print(B)
print(v)

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]]
[0 1 2 3]


In [15]:
# 2.1. Matrix–vector product  
ref_mat_vec_prod = np.matmul(A,v) # 3,4 @ 4 = 3
print(ref_mat_vec_prod)
einsum_mat_vec_prod = np.einsum('ij,j->i',A,v) # contract over j (the columns of A); 'ij,i->i' would pair v with the rows
print(einsum_mat_vec_prod)


[14 38 62]
[14 38 62]


In [16]:
# matrix-matrix product
ref_mat_mat_prod = np.matmul(A,B) # 3,4 @ 4,5 = 3,5
print(ref_mat_mat_prod)
einsum_mat_mat_prod = np.einsum('ij,jk->ik',A,B)
print(einsum_mat_mat_prod)
print(np.allclose(ref_mat_mat_prod, einsum_mat_mat_prod), einsum_mat_mat_prod.shape, ref_mat_mat_prod.shape)

[[ 70  76  82  88  94]
 [190 212 234 256 278]
 [310 348 386 424 462]]
[[ 70  76  82  88  94]
 [190 212 234 256 278]
 [310 348 386 424 462]]
True (3, 5) (3, 5)


In [17]:
# Batched matrix–matrix product
X_BLD = np.arange(120).reshape(10,3,4) 
y_BDF = np.arange(200).reshape(10,4,5)
v_BD = np.arange(40).reshape(10, 4)


In [18]:
ref_batched_mat_mat_prod_BLF = np.matmul(X_BLD,y_BDF) # 10,3,4 @ 10,4,5 = 10,3,5
print(ref_batched_mat_mat_prod_BLF.shape)
# i = batch axis (kept), k = contracted axis, j and m = output rows and columns
einsum_batched_mat_mat_prod_BLF = np.einsum('ijk,ikm->ijm',X_BLD,y_BDF)
print(einsum_batched_mat_mat_prod_BLF.shape)

(10, 3, 5)
(10, 3, 5)
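
The array `v_BD` defined above suggests a batched matrix–vector product as well. A minimal sketch, assuming that is the intended use: each batch entry multiplies a (3, 4) matrix by a length-4 vector.

```python
import numpy as np

X_BLD = np.arange(120).reshape(10, 3, 4)
v_BD = np.arange(40).reshape(10, 4)

# 'b' is kept as the batch axis, 'd' is contracted, 'l' survives.
einsum_matvec_BL = np.einsum('bld,bd->bl', X_BLD, v_BD)

# Reference: treat each vector as a (4, 1) column, matmul, then drop the last axis.
ref_matvec_BL = np.matmul(X_BLD, v_BD[:, :, None])[:, :, 0]

print(np.array_equal(einsum_matvec_BL, ref_matvec_BL), einsum_matvec_BL.shape)
```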


## Level 3 – Diagonals, traces, and permutations

In [19]:
A_sq_mat = np.arange(16).reshape(4,4)
print(A_sq_mat)

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]


In [20]:
#  3.1. Trace of a matrix  
ref_trace= np.trace(A_sq_mat)
print(ref_trace)
einsum_trace = np.einsum('ii',A_sq_mat)
print(einsum_trace)

30
30


In [21]:
# 3.2. Extract diagonal as a vector  
ref_diagonal = np.diagonal(A_sq_mat)
print(ref_diagonal)
einsum_diagonal = np.einsum('ii->i',A_sq_mat)
print(einsum_diagonal)

[ 0  5 10 15]
[ 0  5 10 15]


In [22]:
d = np.arange(4).reshape(4)
print(d)

[0 1 2 3]


In [23]:
# 3.3. Construct diagonal matrix from a vector  
ref_diag_matrix = np.diag(d)
print(ref_diag_matrix)
print(np.eye(len(d)))
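# np.einsum rejects 'i->ii': a label may not repeat in the output subscripts.
# Two workarounds: scale the rows of an identity matrix,
#   np.einsum('i,ij->ij', d, np.eye(len(d)))
# or assign d through the writeable diagonal view np.einsum('ii->i', D) of a zero matrix D.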

[[0 0 0 0]
 [0 1 0 0]
 [0 0 2 0]
 [0 0 0 3]]
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]
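
Exercise 3.3 can still be solved with `einsum`, just not with a repeated output label. The sketch below shows two possible routes; the names `diag_via_product` and `diag_via_view` are illustrative.

```python
import numpy as np

d = np.arange(4)

# Option 1: scale row i of an identity matrix by d[i].
eye = np.eye(len(d), dtype=d.dtype)
diag_via_product = np.einsum('i,ij->ij', d, eye)

# Option 2: write through the diagonal view of a zero matrix.
# 'ii->i' returns a view, which is writeable when the input array is.
diag_via_view = np.zeros((len(d), len(d)), dtype=d.dtype)
np.einsum('ii->i', diag_via_view)[:] = d

print(np.array_equal(diag_via_product, np.diag(d)))
print(np.array_equal(diag_via_view, np.diag(d)))
```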


In [24]:
x_sq_mat_BDD = np.arange(160).reshape(10,4,4)

In [25]:
# 3.5. Diagonal along the last two axes of a 3D tensor
ref_diag_x_sq_mat_BDD = np.diagonal(x_sq_mat_BDD, axis1=1, axis2=2)
print(ref_diag_x_sq_mat_BDD)
einsum_diag_x_sq_mat_BDD = np.einsum('bii->bi', x_sq_mat_BDD)
print(np.einsum('...ii->...i', x_sq_mat_BDD))
print(einsum_diag_x_sq_mat_BDD)
print(np.allclose(einsum_diag_x_sq_mat_BDD, ref_diag_x_sq_mat_BDD), einsum_diag_x_sq_mat_BDD.shape, ref_diag_x_sq_mat_BDD.shape)

[[  0   5  10  15]
 [ 16  21  26  31]
 [ 32  37  42  47]
 [ 48  53  58  63]
 [ 64  69  74  79]
 [ 80  85  90  95]
 [ 96 101 106 111]
 [112 117 122 127]
 [128 133 138 143]
 [144 149 154 159]]
[[  0   5  10  15]
 [ 16  21  26  31]
 [ 32  37  42  47]
 [ 48  53  58  63]
 [ 64  69  74  79]
 [ 80  85  90  95]
 [ 96 101 106 111]
 [112 117 122 127]
 [128 133 138 143]
 [144 149 154 159]]
[[  0   5  10  15]
 [ 16  21  26  31]
 [ 32  37  42  47]
 [ 48  53  58  63]
 [ 64  69  74  79]
 [ 80  85  90  95]
 [ 96 101 106 111]
 [112 117 122 127]
 [128 133 138 143]
 [144 149 154 159]]
True (10, 4) (10, 4)
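
Dropping the diagonal index from the output turns the batched diagonal into a batched trace. A minimal sketch on the same array:

```python
import numpy as np

x_sq_mat_BDD = np.arange(160).reshape(10, 4, 4)

# The repeated 'i' picks the diagonal; omitting it from the output sums it.
einsum_trace_B = np.einsum('bii->b', x_sq_mat_BDD)
ref_trace_B = np.trace(x_sq_mat_BDD, axis1=1, axis2=2)

print(np.array_equal(einsum_trace_B, ref_trace_B), einsum_trace_B.shape)
```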


## Level 4 – Broadcasting, elementwise ops, ellipsis

In [50]:
X_BD = np.arange(200).reshape(10,20)
y_BD = np.arange(200).reshape(10,20)
b_B = np.arange(20).reshape(20)
print(X_BD)

[[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
   18  19]
 [ 20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37
   38  39]
 [ 40  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57
   58  59]
 [ 60  61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77
   78  79]
 [ 80  81  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97
   98  99]
 [100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
  118 119]
 [120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
  138 139]
 [140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
  158 159]
 [160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
  178 179]
 [180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
  198 199]]


In [51]:
# 4.1. Elementwise multiplication (no reduction)  
# •  Use einsum to compute elementwise product X * Y.  
# •  Compare to X * Y.
ref_elementwise_mul_Xy = X_BD*y_BD
print(ref_elementwise_mul_Xy)
einsum_elementwise_mul_Xy = np.einsum('...i,...i->...i',X_BD,y_BD)
print(einsum_elementwise_mul_Xy)
print(np.allclose(einsum_elementwise_mul_Xy, ref_elementwise_mul_Xy), einsum_elementwise_mul_Xy.shape, ref_elementwise_mul_Xy.shape)

[[    0     1     4     9    16    25    36    49    64    81   100   121
    144   169   196   225   256   289   324   361]
 [  400   441   484   529   576   625   676   729   784   841   900   961
   1024  1089  1156  1225  1296  1369  1444  1521]
 [ 1600  1681  1764  1849  1936  2025  2116  2209  2304  2401  2500  2601
   2704  2809  2916  3025  3136  3249  3364  3481]
 [ 3600  3721  3844  3969  4096  4225  4356  4489  4624  4761  4900  5041
   5184  5329  5476  5625  5776  5929  6084  6241]
 [ 6400  6561  6724  6889  7056  7225  7396  7569  7744  7921  8100  8281
   8464  8649  8836  9025  9216  9409  9604  9801]
 [10000 10201 10404 10609 10816 11025 11236 11449 11664 11881 12100 12321
  12544 12769 12996 13225 13456 13689 13924 14161]
 [14400 14641 14884 15129 15376 15625 15876 16129 16384 16641 16900 17161
  17424 17689 17956 18225 18496 18769 19044 19321]
 [19600 19881 20164 20449 20736 21025 21316 21609 21904 22201 22500 22801
  23104 23409 23716 24025 24336 24649 24964 25281]


In [52]:
# 4.2. Per-row sum using ellipsis  
# •  Use einsum to sum each row of X, i.e., shape (10,).  
# •  Use the same code on a 3D array Z = rng.normal(size=(5, 10, 20)) but keep everything except the last axis.  
# ◦  Verify it behaves like np.sum(Z, axis=-1) for both 2D and 3D.
print(X_BD.shape)
ref_per_row_sum_B = np.sum(X_BD,axis=-1)
print(ref_per_row_sum_B)
einsum_per_row_sum_B = np.einsum('...i->...',X_BD)
print(einsum_per_row_sum_B)
print(np.allclose(ref_per_row_sum_B, einsum_per_row_sum_B), einsum_per_row_sum_B.shape, ref_per_row_sum_B.shape)

(10, 20)
[ 190  590  990 1390 1790 2190 2590 2990 3390 3790]
[ 190  590  990 1390 1790 2190 2590 2990 3390 3790]
True (10,) (10,)


In [54]:
ex_X_BDL = np.arange(24).reshape(2,3,4)
ex_y_BDL = np.arange(24).reshape(2,3,4)

In [71]:
print(ex_X_BDL)
ref_per_row_sum_BD = np.sum(ex_X_BDL, axis=-1)
print('np.sum(ex_X_BDL,axis=-1) \n', ref_per_row_sum_BD)
# Sum over the last axis only: shape (2, 3).
einsum_per_row_sum_BD = np.einsum('...i->...',ex_X_BDL)
print(np.allclose(ref_per_row_sum_BD, einsum_per_row_sum_BD), einsum_per_row_sum_BD.shape, ref_per_row_sum_BD.shape)
# Sum over the last two axes: one value per batch entry, shape (2,).
einsum_per_row_sum_BL = np.einsum('...ij->...',ex_X_BDL)
print('einsum_per_row_sum_BL \n', einsum_per_row_sum_BL)
ref_per_row_sum_BL = np.sum(ex_X_BDL, axis=(-1,-2))
print(np.allclose(ref_per_row_sum_BL, einsum_per_row_sum_BL), einsum_per_row_sum_BL.shape, ref_per_row_sum_BL.shape)

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
np.sum(ex_X_BDL,axis=-1) 
 [[ 6 22 38]
 [54 70 86]]
True (2, 3) (2, 3)
einsum_per_row_sum_BL 
 [ 66 210]
True (2,) (2,)


In [62]:
# 4.3. Batched dot with ellipsis 
# Use einsum with ellipsis to compute a per-batch dot product: one scalar per batch entry, so shape (2,) for the (2, 3, 4) arrays used here.  
# •  Compare to np.sum(ex_X_BDL * ex_y_BDL, axis=(-1, -2)) and to a loop.
ref_num_dot_B = np.sum(ex_X_BDL * ex_y_BDL, axis=(-1,-2))
print(ref_num_dot_B)
print("ex_X_BDL * ex_y_BDL \n",ex_X_BDL * ex_y_BDL)

print("np.einsum('ijk,ijk->ijk',ex_X_BDL,ex_y_BDL) \n", np.einsum('ijk,ijk->ijk',ex_X_BDL,ex_y_BDL))
einsum_num_dot_B = np.einsum('...ij,...ij->...',ex_X_BDL, ex_y_BDL)
print(einsum_num_dot_B)


[ 506 3818]
ex_X_BDL * ex_y_BDL 
 [[[  0   1   4   9]
  [ 16  25  36  49]
  [ 64  81 100 121]]

 [[144 169 196 225]
  [256 289 324 361]
  [400 441 484 529]]]
np.einsum('ijk,ijk->ijk',ex_X_BDL,ex_y_BDL) 
 [[[  0   1   4   9]
  [ 16  25  36  49]
  [ 64  81 100 121]]

 [[144 169 196 225]
  [256 289 324 361]
  [400 441 484 529]]]
[ 506 3818]
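
The exercise also asks for a comparison against a loop. A minimal sketch, assuming "per-batch dot product" means flattening each (3, 4) slice and taking an ordinary dot product:

```python
import numpy as np

ex_X_BDL = np.arange(24).reshape(2, 3, 4)
ex_y_BDL = np.arange(24).reshape(2, 3, 4)

einsum_num_dot_B = np.einsum('...ij,...ij->...', ex_X_BDL, ex_y_BDL)

# Explicit loop over the batch axis: flatten each slice and take a dot product.
loop_num_dot_B = np.array([
    np.dot(ex_X_BDL[b].ravel(), ex_y_BDL[b].ravel())
    for b in range(ex_X_BDL.shape[0])
])

print(einsum_num_dot_B, loop_num_dot_B)
print(np.array_equal(einsum_num_dot_B, loop_num_dot_B))
```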


In [70]:
# 4.4. Add bias with broadcasting  
# •  Use einsum to add vector b as a bias to each row of X, result shape (10, 20).  
# •  Compare to X + b.
print(X_BD.shape, b_B.shape)
ref_X_b_BD = X_BD + b_B
print(ref_X_b_BD.shape)
# print('X_BD \n',X_BD)
# print('y_BD \n', b_B)
# print('X_BD + b_B \n', X_BD + b_B)

(10, 20) (20,)
(10, 20)


In [88]:
# •  Use einsum to add vector b as a bias to each row of X, result shape (10, 20).  
# einsum only multiplies operands and sums over labels, so it has no direct way to add two operands.
# 'ij,j->ij' broadcasts b across the rows of X but multiplies instead of adding:
result_mult = np.einsum('ij,j->ij', X_BD, b_B)  # elementwise product, not the bias add
print("Multiply result (wrong):\n", result_mult)

# For a bias add, plain broadcasting is the natural tool:
result_add = X_BD + b_B  # b_B has shape (20,) and broadcasts across the 10 rows
print("\nCorrect addition result:\n", result_add)
print(np.allclose(result_add, ref_X_b_BD), result_add.shape, ref_X_b_BD.shape)

Multiply result (wrong):
 [[   0    1    4    9   16   25   36   49   64   81  100  121  144  169
   196  225  256  289  324  361]
 [   0   21   44   69   96  125  156  189  224  261  300  341  384  429
   476  525  576  629  684  741]
 [   0   41   84  129  176  225  276  329  384  441  500  561  624  689
   756  825  896  969 1044 1121]
 [   0   61  124  189  256  325  396  469  544  621  700  781  864  949
  1036 1125 1216 1309 1404 1501]
 [   0   81  164  249  336  425  516  609  704  801  900 1001 1104 1209
  1316 1425 1536 1649 1764 1881]
 [   0  101  204  309  416  525  636  749  864  981 1100 1221 1344 1469
  1596 1725 1856 1989 2124 2261]
 [   0  121  244  369  496  625  756  889 1024 1161 1300 1441 1584 1729
  1876 2025 2176 2329 2484 2641]
 [   0  141  284  429  576  725  876 1029 1184 1341 1500 1661 1824 1989
  2156 2325 2496 2669 2844 3021]
 [   0  161  324  489  656  825  996 1169 1344 1521 1700 1881 2064 2249
  2436 2625 2816 3009 3204 3401]
 [   0  181  364  549  736  9
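
If the bias add has to go through `einsum`, one possible workaround is to turn the addition into a sum over an extra axis. This is a sketch of a trick, not an idiomatic solution: `X + b` stays the sensible choice. The names `left_KBD`, `right_KD`, and `b_D` are introduced here for illustration (`b_D` holds the same values as `b_B` in the cells above).

```python
import numpy as np

X_BD = np.arange(200).reshape(10, 20)
b_D = np.arange(20)

# Stack X with ones, and ones with b, along a new leading axis 'k' of length 2.
left_KBD = np.stack([X_BD, np.ones_like(X_BD)])  # shape (2, 10, 20)
right_KD = np.stack([np.ones_like(b_D), b_D])    # shape (2, 20)

# Summing over k gives X[i, j] * 1 + 1 * b[j], which is X + b.
einsum_bias_BD = np.einsum('kij,kj->ij', left_KBD, right_KD)

print(np.array_equal(einsum_bias_BD, X_BD + b_D), einsum_bias_BD.shape)
```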

## Level 5 – Multi‑operand contractions and “real” use cases

In [93]:
A = np.arange(24).reshape(2,3,4)
B = np.arange(40).reshape(2,4,5)
C = np.arange(60).reshape(2,5,6)

In [None]:
# 5.1. Chain of matrix multiplies  
# •  Use one einsum call with three operands A, B, C to compute A @ B @ C.  
#    Here A, B, C carry a leading batch axis of size 2, so the result has shape (2, 3, 6).  
# •  Compare to (A @ B) @ C.
ref_ABC_BL = A @ B @ C
print(ref_ABC_BL.shape)
# Without the batch axis this would be: np.einsum('jk,kl,lm->jm', A, B, C)
einsum_ABC_BL = np.einsum('...jk,...kl,...lm->...jm',A,B,C)
print(einsum_ABC_BL.shape)

(2, 3, 6)
(2, 3, 6)
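
For the unbatched case that the exercise text describes (result shape (3, 6)), a minimal sketch follows. The arrays `A2`, `B2`, `C2` are illustrative stand-ins without the batch axis. `optimize=True` is an existing `np.einsum` option that lets it pick a pairwise contraction order instead of contracting all three operands at once.

```python
import numpy as np

A2 = np.arange(12).reshape(3, 4)
B2 = np.arange(20).reshape(4, 5)
C2 = np.arange(30).reshape(5, 6)

ref_ABC = (A2 @ B2) @ C2

# 'k' and 'l' are shared between neighbours and absent from the output: both are summed.
einsum_ABC = np.einsum('jk,kl,lm->jm', A2, B2, C2)
# Same contraction, with einsum choosing the order of the pairwise products.
einsum_ABC_opt = np.einsum('jk,kl,lm->jm', A2, B2, C2, optimize=True)

print(np.array_equal(einsum_ABC, ref_ABC), einsum_ABC.shape)
print(np.array_equal(einsum_ABC_opt, ref_ABC))
```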


In [106]:
# 5.2. Bilinear form xᵀ M x  
# •  Use einsum to compute the scalar xᵀ M x.  
# •  Compare to x @ M @ x.
x = np.arange(4).reshape(4,)
y = np.arange(5).reshape(5,)
M = np.arange(16).reshape(4,4)
print(x)
print(y)
print(M)

[0 1 2 3]
[0 1 2 3 4]
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]


In [None]:
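ref_xMx = x @ M @ x 
print(ref_xMx.shape)
print(ref_xMx)
# 'i' and 'j' are both absent from the output, so both are summed: sum_ij x_i M_ij x_j
einsum_xMx = np.einsum('i,ij,j->', x, M, x)
print(einsum_xMx)

()
420
420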
