For matrix multiplication, the subscript string `ik,kj->ij` encodes:
$$
\sum_k A_{i,k} B_{k,j} \stackrel{\text{assigned to}}{\rightarrow} M_{i,j}
$$

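* The index $k$ appears in the inputs but not on the right-hand side (RHS) of `->`, so it is summed over.
* The free indices $i$ and $j$ are kept, and their order on the RHS fixes the order of the output dimensions.
* A comma `,` separates the subscripts of different operands when there is more than one.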
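To make these rules concrete, here is a minimal sketch that spells out `ik,kj->ij` as explicit loops. `matmul_loops` is a hypothetical helper written for illustration only; it is checked against `torch.einsum`.

```python
import torch

def matmul_loops(A, B):
    # Hypothetical helper: explicit-loop reading of "ik,kj->ij".
    # i and j index the output; k appears only in the inputs, so it is summed.
    I, K = A.shape
    K2, J = B.shape
    assert K == K2, "the shared index k must have the same size in both operands"
    M = torch.zeros(I, J, dtype=A.dtype)
    for i in range(I):
        for j in range(J):
            for k in range(K):
                M[i, j] += A[i, k] * B[k, j]
    return M

A = torch.randn(2, 3, dtype=torch.float64)
B = torch.randn(3, 4, dtype=torch.float64)
M = matmul_loops(A, B)
print(torch.allclose(M, torch.einsum("ik,kj->ij", A, B)))  # True
```

The loops are only for reading the notation; `einsum` itself dispatches to optimized kernels.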

In [22]:
import torch

## Simple Sum

In [57]:
x = torch.randn(3)
x

tensor([-0.4001, -1.0764,  1.2867])

In [58]:
torch.einsum("i->", x), x.sum()

(tensor(-0.1899), tensor(-0.1899))
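The same rule extends to higher-rank inputs: an empty RHS sums over every index and returns a scalar. A small sketch with a hypothetical 2×3 tensor `Y`:

```python
import torch

Y = torch.arange(6, dtype=torch.float64).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
total = torch.einsum("ij->", Y)  # no output indices, so both i and j are summed
print(total.item(), Y.sum().item())  # 15.0 15.0
```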

## Transpose

In [45]:
A = torch.randn(3, 3)
A

tensor([[ 0.1509, -0.0523, -0.7422],
        [ 0.4386, -0.3916,  1.2186],
        [-0.3641,  0.4673, -1.1790]])

In [46]:
torch.einsum("ij->ji", A)

tensor([[ 0.1509,  0.4386, -0.3641],
        [-0.0523, -0.3916,  0.4673],
        [-0.7422,  1.2186, -1.1790]])
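`torch.einsum` also accepts an ellipsis (`...`) that stands for any leading dimensions, so the same pattern can transpose the last two axes of a batch of matrices. A sketch with a hypothetical batch `T`, compared against `transpose(-2, -1)`:

```python
import torch

T = torch.randn(4, 2, 3)                  # a batch of four 2x3 matrices
Tt = torch.einsum("...ij->...ji", T)      # swap only the last two axes
print(Tt.shape)                           # torch.Size([4, 3, 2])
print(torch.allclose(Tt, T.transpose(-2, -1)))  # True
```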

## Row and Column Sum

In [48]:
torch.einsum("ij->i", A), A.sum(-1)

(tensor([-0.6436,  1.2655, -1.0759]), tensor([-0.6436,  1.2655, -1.0759]))

In [49]:
torch.einsum("ij->j", A), A.sum(0)

(tensor([ 0.2253,  0.0234, -0.7026]), tensor([ 0.2253,  0.0234, -0.7026]))
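The row and column sums generalize to any axis: whichever index is dropped from the RHS is reduced. A sketch on a hypothetical 3-D tensor `X`, where dropping `j` matches `sum(dim=1)`:

```python
import torch

X = torch.randn(2, 3, 4, dtype=torch.float64)
# j is missing from the output, so it is summed; i and k survive in that order
s = torch.einsum("ijk->ik", X)
print(s.shape, torch.allclose(s, X.sum(dim=1)))  # torch.Size([2, 4]) True
```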

## Trace

In [36]:
torch.einsum("ii->", A), A.diag().sum()

(tensor(-0.8049), tensor(-0.8049))
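A repeated index within a single operand selects the diagonal. Keeping that index on the RHS returns the diagonal itself; dropping it sums the diagonal, which is the trace. A sketch on a fresh matrix `D`, checked against `torch.diagonal` and `torch.trace`:

```python
import torch

D = torch.randn(3, 3, dtype=torch.float64)
diag = torch.einsum("ii->i", D)   # i kept: the entries D[i, i]
tr = torch.einsum("ii->", D)      # i dropped: sum of the diagonal
print(torch.allclose(diag, torch.diagonal(D)))  # True
print(torch.allclose(tr, torch.trace(D)))       # True
```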

## Permute Dimensions

In [54]:
A = torch.randn(1, 2, 3)
A

tensor([[[-0.3725, -1.1816,  0.4561],
         [ 1.2399,  0.4473,  0.9907]]])

In [55]:
torch.einsum("ijk->kji", A)

tensor([[[-0.3725],
         [ 1.2399]],

        [[-1.1816],
         [ 0.4473]],

        [[ 0.4561],
         [ 0.9907]]])

In [56]:
A.permute(2, 1, 0)

tensor([[[-0.3725],
         [ 1.2399]],

        [[-1.1816],
         [ 0.4473]],

        [[ 0.4561],
         [ 0.9907]]])
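Permutation and reduction can be combined in one subscript string. In the sketch below, with a hypothetical tensor `P`, `j` is summed and the remaining axes come out in the order `k, i`, which matches a `sum` followed by a transpose:

```python
import torch

P = torch.randn(2, 3, 4, dtype=torch.float64)
out = torch.einsum("ijk->ki", P)         # sum over j, then order the result as (k, i)
ref = P.sum(dim=1).transpose(0, 1)       # (2, 4) -> (4, 2)
print(out.shape, torch.allclose(out, ref))  # torch.Size([4, 2]) True
```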

## Inner and Outer Products

In [37]:
a = torch.randn(3)
b = torch.randn(3)
a, b

(tensor([ 1.3860,  1.0339, -0.5946]), tensor([ 0.1046, -0.1353, -1.8648]))

In [38]:
torch.einsum("i,i->", a, b), (a * b).sum()

(tensor(1.1138), tensor(1.1138))
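Adding an index that appears on both sides of `->` turns the inner product into a batch of inner products. A sketch with hypothetical stacks of vectors `U` and `V`, one dot product per row:

```python
import torch

U = torch.randn(5, 3, dtype=torch.float64)  # five vectors of length 3
V = torch.randn(5, 3, dtype=torch.float64)
# b is on the RHS, so it is kept; i is summed, giving one dot product per row
dots = torch.einsum("bi,bi->b", U, V)
print(torch.allclose(dots, (U * V).sum(-1)))           # True
print(torch.allclose(dots[0], torch.dot(U[0], V[0])))  # True
```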

In [43]:
torch.einsum("i,j->ij", a, b), (a.unsqueeze(-1) @ b.unsqueeze(0))

(tensor([[ 0.1449, -0.1876, -2.5845],
         [ 0.1081, -0.1399, -1.9279],
         [-0.0622,  0.0805,  1.1087]]),
 tensor([[ 0.1449, -0.1876, -2.5845],
         [ 0.1081, -0.1399, -1.9279],
         [-0.0622,  0.0805,  1.1087]]))
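Because the two operands share no index, nothing is summed, and `i` and `j` need not have the same size. A sketch with vectors of lengths 3 and 4, checked against `torch.outer` and plain broadcasting:

```python
import torch

a = torch.randn(3, dtype=torch.float64)
b = torch.randn(4, dtype=torch.float64)
o = torch.einsum("i,j->ij", a, b)    # no shared index: pure elementwise products
print(o.shape)                                     # torch.Size([3, 4])
print(torch.allclose(o, torch.outer(a, b)))        # True
print(torch.allclose(o, a[:, None] * b[None, :]))  # True
```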

## MatMul (Simple and Batched)

In [39]:
A = torch.randn(2, 3)
B = torch.randn(3, 4)

In [40]:
torch.allclose(A @ B, torch.einsum("ij,jk->ik", A, B))

True

In [41]:
bs = 5
A = torch.randn(bs, 2, 3)
B = torch.randn(bs, 3, 4)
torch.allclose(A @ B, torch.einsum("bij,bjk->bik", A, B))

True
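The batch index only needs to appear on the operands that actually have it. In this sketch a hypothetical weight matrix `W` without a batch index is reused for every batch item, which agrees with broadcasting `@` and with `torch.bmm` on an expanded `W`:

```python
import torch

bs = 5
A = torch.randn(bs, 2, 3, dtype=torch.float64)
W = torch.randn(3, 4, dtype=torch.float64)   # one matrix shared by every batch item
# b is kept, j is summed; W has no b index, so it is reused for each item
out = torch.einsum("bij,jk->bik", A, W)
print(out.shape)                   # torch.Size([5, 2, 4])
print(torch.allclose(out, A @ W))  # True
print(torch.allclose(out, torch.bmm(A, W.expand(bs, 3, 4))))  # True
```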

## Matrix-Vector Product

In [52]:
A = torch.randn(3, 3)
x = torch.randn(3)
torch.einsum("ij,j->i", A, x), A @ x

(tensor([ 0.2009, -0.0788,  0.2560]), tensor([ 0.2009, -0.0788,  0.2560]))
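The same notation batches the matrix-vector product and handles more than two operands in one call. A sketch with hypothetical batched inputs, including the bilinear form $x^\top A x$ for each batch item:

```python
import torch

A = torch.randn(5, 3, 3, dtype=torch.float64)  # batch of matrices
x = torch.randn(5, 3, dtype=torch.float64)     # one vector per matrix
y = torch.einsum("bij,bj->bi", A, x)           # per-item A[b] @ x[b]
ref = (A @ x.unsqueeze(-1)).squeeze(-1)
print(torch.allclose(y, ref))  # True

# Three operands: x^T A x = sum_i x_i (A x)_i for each batch item
q = torch.einsum("bi,bij,bj->b", x, A, x)
print(torch.allclose(q, (x * y).sum(-1)))  # True
```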