In [42]:
!pip install d2l




In [43]:
import torch
from torch import nn
from d2l import torch as d2l


In [44]:
def corr2d(X, kernel):
  '''
  Compute the 2-D cross-correlation of X with kernel.

  This is the sliding-window operation used by convolutional layers
  (the kernel is not flipped). No padding is applied (padding = 0) and
  the window advances one element at a time (stride = 1).

  Args:
    X: 2-D input tensor of shape (H, W).
    kernel: 2-D kernel tensor of shape (h, w).

  Returns:
    Tensor of shape (H - h + 1, W - w + 1), allocated on X's device.
    Floating-point inputs keep their promoted dtype; other inputs
    produce the default floating-point dtype.
  '''
  h, w = kernel.shape
  # Allocate the output from the inputs' dtype/device rather than always
  # as default-dtype on the CPU, so float64 inputs are not silently
  # downcast and GPU inputs do not yield a CPU result.
  dtype = torch.result_type(X, kernel)
  if not (dtype.is_floating_point or dtype.is_complex):
    dtype = torch.get_default_dtype()
  Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1),
                  dtype=dtype, device=X.device)     # output shape
  for i in range(Y.shape[0]):
    for j in range(Y.shape[1]):
      # element-wise product of the current window with the kernel, summed
      Y[i, j] = (X[i:i + h, j:j + w] * kernel).sum()
  return Y

In [45]:
# Sanity-check corr2d on a 3x3 input (values 0..8) with a 2x2 kernel (values 0..3)
X = torch.arange(9.0).reshape(3, 3)
kernel = torch.arange(4.0).reshape(2, 2)
corr2d(X, kernel)


tensor([[19., 25.],
        [37., 43.]])

## Two-dimensional convolutional layer

In [46]:
class Conv2D(nn.Module):
  '''Two-dimensional convolutional layer built on top of corr2d.

  Holds a learnable kernel of shape `kernel_size`, initialised with
  torch.rand, and a learnable one-element bias initialised to zero.
  '''

  def __init__(self, kernel_size):
    super(Conv2D, self).__init__()
    initial_kernel = torch.rand(kernel_size)
    initial_bias = torch.zeros(1)
    self.weight = nn.Parameter(initial_kernel)
    self.bias = nn.Parameter(initial_bias)

  def forward(self, x):
    '''Cross-correlate x with the kernel, then add the bias.'''
    correlated = corr2d(x, self.weight)
    return correlated + self.bias

## Use a convolutional layer to detect edges

In [47]:
# 6x8 image: ones on both sides, a band of zeros in columns 2-5
X = torch.cat((torch.ones(6, 2), torch.zeros(6, 4), torch.ones(6, 2)), dim=1)
X

tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])

In [48]:
# A 1x2 kernel [1, -1]: the output is non-zero only where two
# horizontally adjacent pixels differ, i.e. at vertical edges
kernel = torch.tensor([1.0, -1.0]).reshape(1, 2)
Y = corr2d(X, kernel)
Y

tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])

## Learn the kernel (weight) that maps X to Y

In [49]:
# Learn the 1x2 kernel that maps X to Y with plain gradient descent
conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)
# nn.Conv2d expects (batch, channels, height, width)
X = X.reshape((1, 1, 6, 8))
Y = Y.reshape((1, 1, 6, 7))
lr = 3e-2  # learning rate
for i in range(10):
  Y_hat = conv2d(X)
  loss = ((Y_hat - Y) ** 2).sum()   # summed squared error
  conv2d.zero_grad()
  loss.backward()
  # Update the kernel inside no_grad() instead of through `.data`, so the
  # step is not recorded by autograd but the parameter stays tracked
  with torch.no_grad():
    conv2d.weight -= lr * conv2d.weight.grad
  if (i + 1) % 2 == 0:
    print(f'batch {i+1}, loss {loss:.3f}')


batch 2, loss 13.642
batch 4, loss 4.591
batch 6, loss 1.713
batch 8, loss 0.674
batch 10, loss 0.271


In [50]:
# Show the learned kernel; detach() gives a grad-free view without going through `.data`
conv2d.weight.detach().reshape((1, 2))


tensor([[ 0.9360, -1.0429]])

## Multiple channels in and out

In [51]:
import torch
from d2l import torch as d2l

In [65]:
# multiple channel input, one channel output
def corr2d_multi_in(X, K):
  '''Cross-correlate a multi-channel input with a multi-channel kernel.

  Args:
    X: input of shape (c_i, H, W).
    K: kernel of shape (c_i, h, w).

  Returns:
    The sum over channels of the per-channel 2-D cross-correlations.
  '''
  # Use the notebook's own corr2d so this function does not depend on
  # the d2l copy of the same helper.
  return sum(corr2d(x, k) for x, k in zip(X, K))

In [59]:
# Two-channel 3x3 input: channel 0 holds 0..8, channel 1 holds 1..9
first_channel = torch.arange(9.0).reshape(3, 3)
X = torch.stack((first_channel, first_channel + 1))
print('X shape: ', X.shape)

X shape:  torch.Size([2, 3, 3])


In [54]:
# Two-channel 2x2 kernel: channel 0 holds 0..3, channel 1 holds 1..4
first_kernel = torch.arange(4.0).reshape(2, 2)
K = torch.stack((first_kernel, first_kernel + 1))
print('K shape: ', K.shape, '\n', K)

K shape:  torch.Size([2, 2, 2]) 
 tensor([[[0., 1.],
         [2., 3.]],

        [[1., 2.],
         [3., 4.]]])


In [55]:
# Both input channels are correlated with their kernel channel and
# summed into a single 2-D output map
out = corr2d_multi_in(X, K)
out_shape = out.shape
print(out_shape, '\n', out)


torch.Size([2, 2]) 
 tensor([[ 56.,  72.],
        [104., 120.]])


In [56]:
# multiple channel input, multiple channel output
def corr2d_multi_in_out(X, k):
  '''Cross-correlate a multi-channel input with several multi-channel kernels.

  Args:
    X: input passed unchanged to corr2d_multi_in for every kernel.
    k: stack of kernels; iterating over its first dimension yields one
      kernel per output channel.

  Returns:
    The per-kernel results of corr2d_multi_in stacked along a new
    leading (output-channel) dimension.
  '''
  # Iterate over the `k` argument itself. Reading the kernels from a
  # global `K` would silently ignore whatever the caller passed in.
  return torch.stack([corr2d_multi_in(X, kernel) for kernel in k], 0)

In [57]:
# Build three kernels (one per output channel) by offsetting the current K
# by 0, 1 and 2. NOTE: this rebinds K, so re-running the cell stacks again
# and adds another leading dimension.
K = torch.stack((K, K + 1, K + 2), dim=0)
print(K.shape, '\n', K)


torch.Size([3, 2, 2, 2]) 
 tensor([[[[0., 1.],
          [2., 3.]],

         [[1., 2.],
          [3., 4.]]],


        [[[1., 2.],
          [3., 4.]],

         [[2., 3.],
          [4., 5.]]],


        [[[2., 3.],
          [4., 5.]],

         [[3., 4.],
          [5., 6.]]]])


In [67]:
# Show the input and kernel shapes, then the multi-output result
for tensor in (X, K):
  print(tensor.shape)

out = corr2d_multi_in_out(X, K)
print(out.shape, '\n', out)


torch.Size([2, 3, 3])
torch.Size([3, 2, 2, 2])
torch.Size([3, 2, 2]) 
 tensor([[[ 56.,  72.],
         [104., 120.]],

        [[ 76., 100.],
         [148., 172.]],

        [[ 96., 128.],
         [192., 224.]]])


## A 1x1 convolution is equivalent to a fully connected layer

In [74]:
def corr2d_multi_in_out_1x1(X, K):
  '''Apply a 1x1 multi-input, multi-output kernel as one matrix multiply.

  Args:
    X: input of shape (c_i, h, w).
    K: kernel that can be reshaped to (c_o, c_i), where c_o = K.shape[0].

  Returns:
    Tensor of shape (c_o, h, w).
  '''
  c_i, h, w = X.shape
  c_o = K.shape[0]
  pixels = X.reshape((c_i, h * w))    # one column per spatial position
  weights = K.reshape((c_o, c_i))     # one row per output channel
  mixed = weights @ pixels            # mix channels at every position
  return mixed.reshape((c_o, h, w))


# Random 3-channel 3x3 input
X = torch.normal(0, 1, size=(3, 3, 3))
# Two 1x1 kernels, each spanning the 3 input channels
K = torch.normal(0, 1, size=(2, 3, 1, 1))

# The matrix-multiply shortcut must agree with the general implementation
Y1 = corr2d_multi_in_out_1x1(X, K)
Y2 = corr2d_multi_in_out(X, K)
total_abs_diff = (Y1 - Y2).abs().sum()
assert float(total_abs_diff) < 1e-6
