In [123]:
import torch
from torch import nn

def corr2d(X, K):  #@save
    """Return the 2D cross-correlation of input ``X`` with kernel ``K``.

    The kernel is slid over every position where it fits entirely inside
    ``X`` (no padding, stride 1). Each output element is the sum of the
    element-wise product between the kernel and the window beneath it.

    Args:
        X: 2D input tensor.
        K: 2D kernel tensor.

    Returns:
        A tensor of shape
        ``(X.shape[0] - K.shape[0] + 1, X.shape[1] - K.shape[1] + 1)``
        allocated with ``torch.zeros`` and filled element by element.
    """
    k_rows, k_cols = K.shape
    out_rows = X.shape[0] - k_rows + 1
    out_cols = X.shape[1] - k_cols + 1
    Y = torch.zeros((out_rows, out_cols))
    for r in range(out_rows):
        for c in range(out_cols):
            # Window of X currently covered by the kernel.
            window = X[r:r + k_rows, c:c + k_cols]
            Y[r, c] = torch.sum(window * K)
    return Y

## 测试二维互相关运算

In [124]:
# Sanity check: a 3x3 input holding 0..8 cross-correlated with a 2x2 kernel
# holding 0..3.
X = torch.arange(9.0).reshape(3, 3)
K = torch.arange(4.0).reshape(2, 2)
corr2d(X, K)

tensor([[19., 25.],
        [37., 43.]])

## 二维卷积层

In [125]:
class Conv2D(nn.Module):
    """Minimal 2D convolutional layer: cross-correlation plus a scalar bias.

    Args:
        kernel_size: Shape of the kernel, passed straight to ``torch.rand``
            (e.g. ``(2, 2)``).

    Attributes:
        weight: Learnable kernel, initialised uniformly at random in [0, 1).
        bias: Learnable scalar bias of shape ``(1,)``, initialised to zero.
    """

    def __init__(self, kernel_size):
        super().__init__()

        # Pass concrete initial values to Parameter so both tensors are
        # registered as learnable parameters of the module.
        self.weight = nn.Parameter(torch.rand(kernel_size))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, X):
        """Cross-correlate ``X`` with the kernel and add the bias.

        The bias has shape ``(1,)`` and is broadcast over every output
        element. Previously it was registered but never applied, so it had
        no effect on the output and never received a gradient.
        """
        return corr2d(X, self.weight) + self.bias

In [126]:
# Build a layer with a randomly initialised 2x2 kernel and apply it to an
# all-ones 3x3 input; each output element then equals the sum of the
# kernel weights.
conv = Conv2D(kernel_size=(2, 2))

X = torch.ones((3, 3))
Y = conv(X)

# Display the kernel, the input and the layer output together.
conv.weight, X, Y

(Parameter containing:
 tensor([[0.3308, 0.8742],
         [0.9866, 0.7376]], requires_grad=True),
 tensor([[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]),
 tensor([[2.9292, 2.9292],
         [2.9292, 2.9292]], grad_fn=<CopySlices>))

## 边缘检测应用

In [127]:
# Synthetic 6x8 "image": two columns of ones, four columns of zeros, then
# two columns of ones, giving two vertical edges.
X = torch.cat(
    [torch.ones((6, 2)), torch.zeros((6, 4)), torch.ones((6, 2))],
    dim=1,
)
X

tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])

In [128]:
# 1x2 kernel that subtracts each element's right-hand neighbour from it:
# the result is 0 where neighbours are equal, 1 on a 1->0 edge and -1 on a
# 0->1 edge.
K = torch.tensor([1.0, -1.0]).reshape(1, 2)
Y = corr2d(X, K)
Y

tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])

## 自动学习边缘检测 kernel

In [129]:
# Built-in layer with 1 input channel, 1 output channel, a (1, 2) kernel
# and no bias term.
conv2d = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(1, 2), bias=False)

# Layout: (batch size, channels, height, width).
X = X.reshape((1, 1, 6, 8))
Y = Y.reshape((1, 1, 6, 7))
X, Y

(tensor([[[[1., 1., 0., 0., 0., 0., 1., 1.],
           [1., 1., 0., 0., 0., 0., 1., 1.],
           [1., 1., 0., 0., 0., 0., 1., 1.],
           [1., 1., 0., 0., 0., 0., 1., 1.],
           [1., 1., 0., 0., 0., 0., 1., 1.],
           [1., 1., 0., 0., 0., 0., 1., 1.]]]]),
 tensor([[[[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
           [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
           [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
           [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
           [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
           [ 0.,  1.,  0.,  0.,  0., -1.,  0.]]]]))

In [130]:
# Training hyper-parameters for learning the edge-detection kernel.
NUM_EPOCHS = 20
LEARNING_RATE = 0.01

for i in range(NUM_EPOCHS):
    Y_hat = conv2d(X)
    # Element-wise squared error, reduced once to a scalar loss that is
    # reused for both the backward pass and the log line.
    l = (Y_hat - Y) ** 2
    loss = l.sum()
    conv2d.zero_grad()
    loss.backward()

    # Gradient-descent step. The in-place update runs under no_grad() so
    # autograd does not record it; this replaces mutating `.data`, which
    # bypasses autograd's bookkeeping.
    with torch.no_grad():
        conv2d.weight.sub_(LEARNING_RATE * conv2d.weight.grad)

    print(f'Epochs is {i+1}, loss is {loss}')

Epochs is 1, loss is 9.272306442260742
Epochs is 2, loss is 4.873329162597656
Epochs is 3, loss is 3.4047627449035645
Epochs is 4, loss is 2.577585458755493
Epochs is 5, loss is 1.9866318702697754
Epochs is 6, loss is 1.5369356870651245
Epochs is 7, loss is 1.1899610757827759
Epochs is 8, loss is 0.9214671850204468
Epochs is 9, loss is 0.7135779857635498
Epochs is 10, loss is 0.5525938272476196
Epochs is 11, loss is 0.4279283583164215
Epochs is 12, loss is 0.3313876688480377
Epochs is 13, loss is 0.2566266357898712
Epochs is 14, loss is 0.19873164594173431
Epochs is 15, loss is 0.15389779210090637
Epochs is 16, loss is 0.1191784143447876
Epochs is 17, loss is 0.09229173511266708
Epochs is 18, loss is 0.071470707654953
Epochs is 19, loss is 0.05534692853689194
Epochs is 20, loss is 0.04286065697669983


In [131]:
# Inspect the learned kernel; compare it with the hand-crafted
# edge-detection kernel [[1.0, -1.0]] used above.
conv2d.weight

Parameter containing:
tensor([[[[ 0.9474, -0.9474]]]], requires_grad=True)