In [6]:
# in channel and out channel
import torch
from d2l import torch as d2l

# multiple in-channel
# -------------------
def multi_in_corr2d(X, K):
    """Cross-correlate a multi-channel input with a multi-channel kernel.

    ``X`` and ``K`` are walked in lockstep along their first axis; each
    paired slice is passed to ``d2l.corr2d`` and the per-pair results are
    added together into a single output.
    """
    total = 0  # same starting value the built-in ``sum`` uses
    for channel, kernel in zip(X, K):
        total = total + d2l.corr2d(channel, kernel)
    return total

# Two-channel 3x3 input: channel 0 counts 0..8, channel 1 counts 1..9.
X = torch.stack((torch.arange(0.0, 9.0).reshape(3, 3),
                 torch.arange(1.0, 10.0).reshape(3, 3)))
# Matching two-channel 2x2 kernel: channel 0 counts 0..3, channel 1 counts 1..4.
K = torch.stack((torch.arange(0.0, 4.0).reshape(2, 2),
                 torch.arange(1.0, 5.0).reshape(2, 2)))

print(multi_in_corr2d(X, K))

tensor([[ 56.,  72.],
        [104., 120.]])


In [7]:
# multi out-channel

def multi_in_out_corr2d(X, K):
    """Cross-correlate ``X`` with each kernel in ``K`` and stack the results.

    ``K`` is iterated along its first axis; every slice is applied to the
    whole input via ``multi_in_corr2d`` and the outputs are stacked along
    a new leading axis.
    """
    per_output = []
    for kernel in K:
        per_output.append(multi_in_corr2d(X, kernel))
    return torch.stack(per_output, dim=0)

# Build 3 output channels by stacking the kernel shifted by 0, 1 and 2.
K = torch.stack([K + shift for shift in range(3)], dim=0)
# K is laid out as [out-channel, in-channel, height (rows), width (columns)].
print(K.shape)
print(multi_in_out_corr2d(X, K))

torch.Size([3, 2, 2, 2])
tensor([[[ 56.,  72.],
         [104., 120.]],

        [[ 76., 100.],
         [148., 172.]],

        [[ 96., 128.],
         [192., 224.]]])


In [9]:
# 1x1 kernel convolution

def multi_in_out_1x1_corr2d(X, K):
    """Apply a 1x1 kernel as a single matrix multiplication.

    ``X`` is unpacked as (c_in, h, w) and flattened to (c_in, h * w);
    ``K`` is reshaped to (c_out, c_in), where c_out is its first dimension.
    Their matrix product is reshaped to (c_out, h, w) and returned.
    """
    n_in, height, width = X.shape
    n_out = K.shape[0]
    pixels = X.reshape((n_in, height * width))   # one column per spatial position
    weights = K.reshape((n_out, n_in))           # drops the trailing kernel dims
    return (weights @ pixels).reshape((n_out, height, width))

# Check that the matrix-multiplication formulation of a 1x1 convolution
# agrees with the general multi-in/multi-out cross-correlation.
X = torch.normal(0, 1, (3, 3, 3))     # 3 in-channel, height, width
K = torch.normal(0, 1, (2, 3, 1, 1))  # 2 out-channel, 3 in-channel, height, width
Y1 = multi_in_out_1x1_corr2d(X, K)
Y2 = multi_in_out_corr2d(X, K)
print(Y1)
print(Y2)  # show the second result (Y1 was previously printed twice)
# The two paths add the products in different orders, so float results may
# differ in the last bits; compare elementwise with a tolerance, not ``==``.
print(torch.isclose(Y1, Y2, atol=1e-6))

tensor([[[-0.0356, -0.7028, -0.1476],
         [-0.4722, -0.4097, -0.6582],
         [ 1.0104,  0.8743,  0.5525]],

        [[ 0.0515,  2.2388, -2.8437],
         [ 0.7046,  6.6494, -4.0016],
         [ 0.2126, -1.1708, -1.7709]]])
tensor([[[-0.0356, -0.7028, -0.1476],
         [-0.4722, -0.4097, -0.6582],
         [ 1.0104,  0.8743,  0.5525]],

        [[ 0.0515,  2.2388, -2.8437],
         [ 0.7046,  6.6494, -4.0016],
         [ 0.2126, -1.1708, -1.7709]]])
tensor([[[True, True, True],
         [True, True, True],
         [True, True, True]],

        [[True, True, True],
         [True, True, True],
         [True, True, True]]])
