## 爱因斯坦标记法(einops)与PyTorch结合使用

In [1]:
import torch
from einops import rearrange, reduce, repeat

### 使用 rearrange 进行转置和变形

In [2]:
# 1. Transpose: swap the channel axis (i) with the height axis (h).
x = torch.randn(2, 3, 4, 4)  # 4D tensor laid out as bs * ic * h * w

out1 = x.transpose(1, 2)
out2 = rearrange(x, 'b i h w -> b h i w')
# print() so the check is visible when run as a script, not only as a
# notebook cell's display value (and consistent with the next cell).
print(torch.allclose(out1, out2))

True

In [3]:
# 2. Reshape: merge the batch and channel axes, then split them back apart.
out1 = torch.reshape(x, (6, 4, 4))
out2 = rearrange(x, 'b i h w -> (b i) h w')
# Only b is given; einops infers the size of i from the merged axis length.
out3 = rearrange(out2, '(b i) h w -> b i h w', b=2)

print(torch.allclose(out1, out2))  # einops merge agrees with torch reshape
print(torch.allclose(x, out3))  # splitting the merged axis recovers x

True
True


In [4]:
# 3. image2patch: cut each image into non-overlapping p1 x p2 patches and
# flatten every patch (together with its channels) into a single vector.
# p1 and p2 are the patch height and width; h1 and w1 (the patch-grid size)
# are inferred by einops from the input's h and w.
out1 = rearrange(x, 'b i (h1 p1) (w1 p2) -> b (h1 w1) (p1 p2 i)', p1=2, p2=2)
# print() so the shape is shown when run as a script, matching the later cells.
print(out1.shape)  # [batch_size, num_patch, patch_depth]

torch.Size([2, 4, 12])

In [5]:
# 4. Stack tensors: einops accepts a list of tensors and stacks them along a
# new leading axis n (comparable to torch.stack(tensor_list, dim=0)).
tensor_list = [x, x, x]
out1 = rearrange(tensor_list, 'n b i h w -> n b i h w')
# print() so the shape is shown when run as a script, matching the later cells.
print(out1.shape)

torch.Size([3, 2, 3, 4, 4])

### 使用 reduce 进行池化

In [6]:
# 5. Reductions with einops.reduce; any axis missing on the right-hand side
# is reduced. Available reductions include mean, min, max, sum, prod.

# Mean over the w axis; w is dropped from the result.
out1 = reduce(x, 'b i h w -> b i h', reduction='mean')
# Sum over w, leaving a length-1 axis in its place (like keepdim=True).
out2 = reduce(x, 'b i h w -> b i h 1', reduction='sum')
# Max over both h and w.
out3 = reduce(x, 'b i h w -> b i', reduction='max')

print(out1.size())
print(out2.size())
print(out3.size())

torch.Size([2, 3, 4])
torch.Size([2, 3, 4, 1])
torch.Size([2, 3])


In [7]:
# 6. Add a trailing axis of length 1 (comparable to torch.unsqueeze).
out1 = rearrange(x, 'b i h w -> b i h w 1')
print(out1.size())

# 7. Repeat: grow the length-1 axis to length 2 by copying
# (comparable to torch.tile).
out2 = repeat(out1, 'b i h w 1 -> b i h w 2')
print(out2.size())

# Repeat the whole image 2x along h and along w. The factor 2 is the outer
# (left) member of each group, so copies are placed side by side (tiling)
# rather than interleaved element by element.
out3 = repeat(x, 'b i h w -> b i (2 h) (2 w)')
print(out3.size())

torch.Size([2, 3, 4, 4, 1])
torch.Size([2, 3, 4, 4, 2])
torch.Size([2, 3, 8, 8])
