# einsum
```python
import torch
a = torch.rand(2, 3)
b = torch.rand(3, 4)
# c[i, j] = sum over k of a[i, k] * b[k, j], i.e. the same as a @ b
c = torch.einsum('ik,kj->ij', [a, b])
```

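- rule1: An index repeated across different inputs means the input tensors are multiplied along that dimension; here `a` and `b` are **multiplied along dimension** `k`.
- rule2: An index that appears only on the left side of the arrow in `equation` means the intermediate result is **summed** over that dimension; it is a summation index.
- rule3: The indices on the right side of the arrow in `equation` may appear in any order; the output dimensions follow that order.
- spRule1: `equation` supports the ellipsis `...`, which stands for indices the user does not care about.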
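The rules above can be checked with a short sketch. The explicit loops are only a reference implementation for illustration, and the shape `(5, 6, 2, 3)` used for the ellipsis case is an arbitrary assumption, not something taken from the original notebook.

```python
import torch

a = torch.rand(2, 3)
b = torch.rand(3, 4)

# rule1 + rule2: k is shared by both inputs and absent from the output,
# so the products a[i, k] * b[k, j] are summed over k.
c_loop = torch.zeros(2, 4)
for i in range(2):
    for j in range(4):
        for k in range(3):
            c_loop[i, j] += a[i, k] * b[k, j]
c = torch.einsum('ik,kj->ij', a, b)
rules_1_2_ok = torch.allclose(c, c_loop)

# rule3: writing the output indices as 'ji' instead of 'ij' transposes the result.
c_t = torch.einsum('ik,kj->ji', a, b)
rule_3_ok = torch.allclose(c_t, c.T)

# spRule1: '...' stands for any leading dimensions that are carried through
# unchanged, here the two leading dimensions (5, 6) of x.
x = torch.rand(5, 6, 2, 3)
y = torch.einsum('...ik,kj->...ij', x, b)
ellipsis_ok = y.shape == (5, 6, 2, 4) and torch.allclose(y, x @ b)
```

Note that the operands are passed here as separate arguments (`a, b`) rather than as a list; both call styles are accepted by `torch.einsum`.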

In [2]:
import torch
import numpy as np

## 1- diagonal

In [3]:
a = torch.arange(9).reshape(3, 3)
torch.einsum('ii->i', [a]), torch.diagonal(a, 0)

(tensor([0, 4, 8]), tensor([0, 4, 8]))

In [4]:
# diagonal sum
torch.einsum('ii', [a]), torch.diagonal(a, 0).sum()

(tensor(12), tensor(12))

## 2- reduce sum

In [7]:
torch.einsum('ij->', [a]), a.sum()

(tensor(36), tensor(36))

## 3- dim sum

In [10]:
torch.einsum('ij->i', [a]), a.sum(dim=1), torch.einsum('ij->j', [a]), a.sum(dim=0)

(tensor([ 3, 12, 21]),
 tensor([ 3, 12, 21]),
 tensor([ 9, 12, 15]),
 tensor([ 9, 12, 15]))

## 4- matrix-vector multiplication

In [12]:
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ij,j->i', [a, b]), a @ b

(tensor([ 5, 14]), tensor([ 5, 14]))

## 5- matrix multiplication

In [11]:
a = torch.rand(2, 3)
b = torch.rand(3, 4)

torch.einsum('ik,kj->ij', [a, b]), a @ b

(tensor([[0.5603, 0.9049, 0.4620, 0.0556],
         [0.4624, 1.1891, 0.5609, 0.2867]]),
 tensor([[0.5603, 0.9049, 0.4620, 0.0556],
         [0.4624, 1.1891, 0.5609, 0.2867]]))

## 6- vector inner product

In [16]:
a = torch.arange(3)
b = torch.arange(3, 6)
torch.einsum('i,i->', [a, b]), (a * b).sum()

(tensor(14), tensor(14))

## 7- element-wise multiplication with reduce sum

In [17]:
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6, 12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b]), (a * b).sum()

(tensor(145), tensor(145))

## 8- vector outer product

In [18]:
a = torch.arange(3)
b = torch.arange(3, 7)  # [3, 4, 5, 6]
torch.einsum('i,j', [a, b]), torch.repeat_interleave(a, len(b)).reshape(len(a), len(b)) * b

(tensor([[ 0,  0,  0,  0],
         [ 3,  4,  5,  6],
         [ 6,  8, 10, 12]]),
 tensor([[ 0,  0,  0,  0],
         [ 3,  4,  5,  6],
         [ 6,  8, 10, 12]]))
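
The equation `'i,j'` above has no arrow, so einsum runs in implicit mode: indices that appear exactly once are kept in the output, in alphabetical order, which gives `'ij'` here. As a hedged sketch, the same outer product can also be cross-checked against `torch.outer` and plain broadcasting, which may be easier to read than the `repeat_interleave` reference:

```python
import torch

a = torch.arange(3)
b = torch.arange(3, 7)  # [3, 4, 5, 6]

# Explicit form of the implicit equation 'i,j'
outer_einsum = torch.einsum('i,j->ij', a, b)
# Built-in outer product of two 1-D tensors
outer_builtin = torch.outer(a, b)
# Broadcasting: shape (3, 1) times shape (4,) yields shape (3, 4)
outer_broadcast = a[:, None] * b
```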

## 9- batch matrix multiplication

In [19]:
a = torch.randn(2, 3, 5)
b = torch.randn(2, 5, 4)
torch.einsum('bik,bkj->bij', [a, b]), a @ b

(tensor([[[ 4.1401,  1.1497,  3.5784, -0.4273],
          [-1.2875,  0.5160,  2.2896,  3.6230],
          [ 2.1326, -0.6485,  1.5664,  1.2608]],
 
         [[ 0.1734,  0.6592, -0.4031,  2.1443],
          [ 0.0825, -1.2430,  2.9188,  1.1550],
          [-1.4477, -2.0411, -0.7977, -1.7201]]]),
 tensor([[[ 4.1401,  1.1497,  3.5784, -0.4273],
          [-1.2875,  0.5160,  2.2896,  3.6230],
          [ 2.1326, -0.6485,  1.5664,  1.2608]],
 
         [[ 0.1734,  0.6592, -0.4031,  2.1443],
          [ 0.0825, -1.2430,  2.9188,  1.1550],
          [-1.4477, -2.0411, -0.7977, -1.7201]]]))