In chapter_foundations we introduce Einstein summation notation. Skip ahead to see how this works, and then write an implementation of the 1×1 convolution operation using torch.einsum. Compare it to the same operation using torch.conv2d.

An awesome [blog post](https://rockt.github.io/2018/04/30/einsum) on einsum.



In [23]:
import torch
import torch.nn as nn

In [7]:
x = torch.arange(1,7).view(2,-1)
x.shape

torch.Size([2, 3])

In [10]:
x

tensor([[1, 2, 3],
        [4, 5, 6]])

In [8]:
torch.einsum("ij->",x)

tensor(21)

In [11]:
torch.einsum("ij->i",x), torch.einsum("ij->j",x)

(tensor([ 6, 15]), tensor([5, 7, 9]))
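When an index appears on the input side but not after `->`, einsum sums over it. That makes the three calls above ordinary reductions. A quick check against `sum`, reusing the same 2×3 tensor:

```python
import torch

x = torch.arange(1, 7).view(2, -1)  # same 2x3 tensor as above

# Dropping every index from the output sums over all elements.
total = torch.einsum("ij->", x)
# Keeping i sums over j (row sums); keeping j sums over i (column sums).
row_sums = torch.einsum("ij->i", x)
col_sums = torch.einsum("ij->j", x)

print(torch.equal(total, x.sum()))          # True
print(torch.equal(row_sums, x.sum(dim=1)))  # True
print(torch.equal(col_sums, x.sum(dim=0)))  # True
```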

In [19]:
torch.einsum("ij, kj->ik",x, x) == x @ x.T

tensor([[True, True],
        [True, True]])
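In `"ij,kj->ik"` the index `j` appears in both inputs but not in the output, so it is summed, while `i` and `k` are kept. That is exactly `x @ x.T`. The loop version below spells this out; `einsum_ij_kj_ik` is a hypothetical helper written only for illustration, not part of PyTorch.

```python
import torch

def einsum_ij_kj_ik(a, b):
    """Hypothetical loop version of torch.einsum("ij,kj->ik", a, b)."""
    I, J = a.shape
    K, J2 = b.shape
    assert J == J2, "the shared index j must have the same size in both inputs"
    out = torch.zeros(I, K, dtype=a.dtype)
    for i in range(I):          # i is kept in the output
        for k in range(K):      # k is kept in the output
            for j in range(J):  # j is repeated and dropped, so it is summed
                out[i, k] += a[i, j] * b[k, j]
    return out

x = torch.arange(1, 7).view(2, -1)
print(einsum_ij_kj_ik(x, x))  # tensor([[14, 32], [32, 77]])
print(torch.equal(einsum_ij_kj_ik(x, x), torch.einsum("ij,kj->ik", x, x)))  # True
```

The loops are only there to make the summation rule visible; `torch.einsum` dispatches to batched matrix multiplies instead.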

In [20]:
torch.einsum("ij, ij->ij",x, x)

tensor([[ 1,  4,  9],
        [16, 25, 36]])
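With `"ij,ij->ij"` every index survives into the output, so nothing is summed and the result is the element-wise product `x * x`. Dropping both indices (`"ij,ij->"`) sums those products into a single number. A short comparison, under the same 2×3 `x`:

```python
import torch

x = torch.arange(1, 7).view(2, -1)

# All indices kept: element-wise (Hadamard) product, no summation.
hadamard = torch.einsum("ij,ij->ij", x, x)
# No indices kept: the element-wise products are summed to a scalar.
frobenius = torch.einsum("ij,ij->", x, x)

print(torch.equal(hadamard, x * x))  # True
print(frobenius, (x * x).sum())      # tensor(91) tensor(91)
```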

In [25]:
img = torch.rand((32, 1, 28, 28))
c = nn.Conv2d(1, 10, kernel_size=1)

In [27]:
c(img).shape

torch.Size([32, 10, 28, 28])
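With `kernel_size=1` each output pixel depends only on the input pixel at the same position, so the convolution reduces to a weighted sum over input channels at every `(h, w)`. Writing $W$ for the layer's weight (shape out × in × 1 × 1) and $\text{bias}$ for its bias:

$$
\text{out}[b, o, h, w] = \text{bias}[o] + \sum_{c} W[o, c, 0, 0]\,\text{img}[b, c, h, w]
$$

The index `c` is summed and `b`, `h`, `w` are carried through, which is the einsum pattern `"bchw,oc->bohw"` plus a broadcast bias.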

In [31]:
kernel = torch.rand((10, 1, 1, 1))

In [32]:
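# c is contracted (summed); kernel[:, :, 0, 0] drops the 1x1 spatial dims -> shape (out_ch, in_ch)
op = torch.einsum("bchw,oc->bohw", img, kernel[:, :, 0, 0])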

In [36]:
kernel[0,0], img[0,0].shape, op.shape

(tensor([[0.0267]]), torch.Size([28, 28]), torch.Size([32, 10, 28, 28]))

In [38]:
comp = (img[0,0] * kernel[0,0] == op[0,0]).float().sum()

In [39]:
comp  # 784 = 28 * 28, so every pixel of op[0, 0] matches

tensor(784.)
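The cells above check the einsum against a random `kernel` rather than against the `nn.Conv2d` layer itself. The sketch below reuses the layer's own `weight` and `bias`, so the einsum result can be compared directly with both the module and `torch.conv2d`. It assumes 3 input channels (so the sum over `c` actually matters), a fixed seed, and an `atol` of `1e-6` to absorb float32 rounding differences between the two code paths.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumption: 3 input channels so the channel contraction is non-trivial.
img = torch.rand(32, 3, 28, 28)
conv = nn.Conv2d(3, 10, kernel_size=1)

# conv.weight has shape (10, 3, 1, 1); drop the 1x1 spatial dims -> (10, 3).
w = conv.weight[:, :, 0, 0]

# Sum over c, keep b/h/w, produce the new channel axis o, then add the bias per channel.
out_einsum = torch.einsum("bchw,oc->bohw", img, w) + conv.bias.view(1, -1, 1, 1)

out_module = conv(img)
out_functional = torch.conv2d(img, conv.weight, conv.bias)

print(out_einsum.shape)  # torch.Size([32, 10, 28, 28])
print(torch.allclose(out_einsum, out_module, atol=1e-6))      # True
print(torch.allclose(out_einsum, out_functional, atol=1e-6))  # True
```

Element-wise `==` works in the single-channel case above because each output is a single product; once several channels are summed, the einsum and the convolution kernel may add terms in a different order, which is why this version uses `allclose`.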