In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [8]:
def dilation_demo(a=None, kernel_size=3, dilations=(1, 2, 3)):
    """Print the input positions a dilated square kernel samples.

    A kernel with ``kernel_size`` taps per dimension and dilation ``d`` spans
    ``d * (kernel_size - 1) + 1`` input positions. It touches only every
    ``d``-th one, i.e. exactly the strided slice ``a[0:span:d, 0:span:d]``.

    Args:
        a: 2-D tensor to sample from; defaults to a fresh ``torch.randn(7, 7)``.
        kernel_size: number of kernel taps per spatial dimension.
        dilations: dilation factors to demonstrate, printed in this order.

    Prints the full tensor first, then one sampled patch per dilation.
    Returns None.
    """
    if a is None:
        a = torch.randn(7, 7)
    print(a)
    for d in dilations:
        # Effective extent of the dilated kernel: (d - 1) holes between each
        # of the (kernel_size - 1) gaps, plus the kernel_size taps themselves.
        span = d * (kernel_size - 1) + 1
        print(a[0:span:d, 0:span:d])


dilation_demo()

tensor([[ 0.6998, -1.8585,  1.6903,  0.3141, -0.3566,  0.8918,  0.4699],
        [ 1.3772, -0.1361, -1.2222,  0.1592, -0.0675, -0.9032, -0.8027],
        [ 0.8802, -2.5107, -0.8699, -1.6091,  0.4236, -1.6767, -0.7457],
        [ 0.1699,  0.3757, -0.2792,  0.3602, -0.1941,  0.8076, -0.6239],
        [ 0.9251, -0.7313, -1.3957, -0.2472, -0.7104,  0.5917, -0.0107],
        [-1.5601,  0.6924,  0.3242, -3.0884,  0.8980, -1.0992,  1.1466],
        [ 0.2344,  0.7881,  0.7016,  0.1453, -0.9354,  1.3735, -0.5850]])
tensor([[ 0.6998, -1.8585,  1.6903],
        [ 1.3772, -0.1361, -1.2222],
        [ 0.8802, -2.5107, -0.8699]])
tensor([[ 0.6998,  1.6903, -0.3566],
        [ 0.8802, -0.8699,  0.4236],
        [ 0.9251, -1.3957, -0.7104]])
tensor([[ 0.6998,  0.3141,  0.4699],
        [ 0.1699,  0.3602, -0.6239],
        [ 0.2344,  0.1453, -0.5850]])


In [4]:
def matrix_multiplication_for_conv2d_final(input, kernel, bias=None, stride=1,
                                           padding=0, dilation=1, groups=1):
    """Loop-based reference 2D convolution following ``F.conv2d`` semantics.

    Args:
        input: tensor of shape (batch, in_channels, H, W).
        kernel: tensor of shape (out_channels, in_channels // groups, kH, kW).
        bias: optional tensor of shape (out_channels,); zeros when None.
        stride: int or (height, width) pair.
        padding: int or (height, width) pair of zero padding added on both
            sides; only applied when positive.
        dilation: int or (height, width) pair; spacing between kernel taps.
        groups: number of channel groups; must divide both in_channels and
            out_channels.

    Returns:
        Tensor of shape (batch, out_channels, out_H, out_W), with the same
        dtype and device as ``input``.

    Raises:
        AssertionError: if ``groups`` does not divide the channel counts.
        RuntimeError: if the kernel's channel dimension is not
            in_channels // groups, or the dilated kernel does not fit inside
            the padded input.
    """
    def _pair(v):
        # Accept a scalar or an explicit (height, width) pair, like F.conv2d.
        return tuple(v) if isinstance(v, (tuple, list)) else (v, v)

    sh, sw = _pair(stride)
    ph, pw = _pair(padding)
    dh, dw = _pair(dilation)

    if ph > 0 or pw > 0:
        # F.pad lists pads from the last dim backwards: (left, right, top, bottom).
        input = F.pad(input, (pw, pw, ph, ph))

    # batch_size, in_channels, input height, input width (after padding)
    bs, ic, ih, iw = input.shape
    # out_channels, per-group in_channels, kernel height, kernel width
    oc, kic, kh, kw = kernel.shape
    # Grouped convolution requires both channel counts to split evenly.
    assert oc % groups == 0 and ic % groups == 0, "groups必须同时被通道数整除！"
    icg, ocg = ic // groups, oc // groups
    if kic != icg:
        raise RuntimeError(
            f"kernel has {kic} input channels, expected in_channels // groups = {icg}")

    if bias is None:
        bias = input.new_zeros(oc)

    # A dilated kernel inserts (dilation - 1) holes into each of its (k - 1)
    # gaps, so its effective extent is dilation * (k - 1) + 1.
    ekh = dh * (kh - 1) + 1
    ekw = dw * (kw - 1) + 1
    if ekh > ih or ekw > iw:
        raise RuntimeError(
            f"dilated kernel ({ekh}x{ekw}) is larger than padded input ({ih}x{iw})")

    # Output size; dilation is already folded into the effective extent.
    oh = (ih - ekh) // sh + 1
    ow = (iw - ekw) // sw + 1

    # Split the channel axes into (groups, channels-per-group).
    x = input.reshape(bs, groups, icg, ih, iw)
    w = kernel.reshape(groups, ocg, icg, kh, kw)
    # Allocate in the input's dtype/device so float64 inputs stay float64.
    output = input.new_zeros((bs, groups, ocg, oh, ow))

    for b in range(bs):                  # batch
        for g in range(groups):          # channel group
            for o in range(ocg):         # output channel within the group
                weight = w[g, o]         # (icg, kh, kw)
                for oi in range(oh):
                    i = oi * sh
                    for oj in range(ow):
                        j = oj * sw
                        # Sample every dilation-th position of the window and
                        # reduce over all in-group input channels at once.
                        region = x[b, g, :, i:i + ekh:dh, j:j + ekw:dw]
                        output[b, g, o, oi, oj] = torch.sum(region * weight)
                # Bias index = channels in earlier groups + offset in this group.
                output[b, g, o] += bias[g * ocg + o]

    # Merge (groups, channels-per-group) back into a single channel axis.
    return output.reshape(bs, oc, oh, ow)


def test_conv2d_final(bs=2, ic=2, ih=5, iw=5, kh=3, kw=3, oc=4,
                      groups=2, dilation=2, stride=2, padding=1, seed=None):
    """Compare the loop-based conv with ``F.conv2d`` and print whether they agree.

    Every argument defaults to the configuration previously hard-coded here.
    Pass ``seed`` to make the random input/kernel/bias reproducible. Prints a
    single bool and returns None.
    """
    if seed is not None:
        torch.manual_seed(seed)

    input = torch.randn(bs, ic, ih, iw)
    # With groups > 1 each kernel only sees ic // groups input channels.
    kernel = torch.randn(oc, ic // groups, kh, kw)
    bias = torch.randn(oc)

    py_res = F.conv2d(input, kernel, bias=bias, padding=padding, stride=stride,
                      dilation=dilation, groups=groups)
    my_res = matrix_multiplication_for_conv2d_final(
        input, kernel, bias=bias, padding=padding, stride=stride,
        dilation=dilation, groups=groups)

    # torch.allclose broadcasts its arguments. Compare shapes first so that a
    # wrong but broadcastable output shape is not reported as a match.
    flag = py_res.shape == my_res.shape and torch.allclose(py_res, my_res)
    print(flag)


test_conv2d_final()

True
