
Course slides: https://zh-v2.d2l.ai/chapter_convolutional-neural-networks/channels.html

MIMO (multiple input and multiple output channels)

*   Each input channel has its own convolution kernel
*   Each output channel can recognize a specific pattern
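Concretely, a layer with c_i input channels and c_o output channels holds one kernel for every (output channel, input channel) pair. The sketch below inspects this layout with PyTorch's `nn.Conv2d`; the channel counts and kernel size are arbitrary choices for illustration.

```python
import torch
from torch import nn

# Illustrative layer: 2 input channels, 3 output channels, 2x2 kernels, no bias
conv = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=2, bias=False)

# Weight layout is (c_o, c_i, k_h, k_w): one 2x2 kernel per (output, input) channel pair
print(conv.weight.shape)  # torch.Size([3, 2, 2, 2])

# A batch of one 2-channel 3x3 input yields 3 output channels of size 2x2
X = torch.randn(1, 2, 3, 3)
print(conv(X).shape)  # torch.Size([1, 3, 2, 2])
```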


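1x1 convolutional layer (kernel height == width == 1)
It does not recognize spatial patterns; it only fuses information across channels. It is equivalent to a fully connected layer applied to an input of shape hw x c_i with a weight matrix of shape c_o x c_i.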
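To make the equivalence concrete, the sketch below (sizes chosen arbitrarily) copies the weights of a bias-free `nn.Conv2d` with `kernel_size=1` into a bias-free `nn.Linear`, treats each of the h*w pixels as one sample with c_i features, and checks that both layers produce the same output.

```python
import torch
from torch import nn

c_i, c_o, h, w = 3, 2, 4, 5  # arbitrary sizes for illustration
X = torch.randn(1, c_i, h, w)

conv = nn.Conv2d(c_i, c_o, kernel_size=1, bias=False)
fc = nn.Linear(c_i, c_o, bias=False)
# Share weights: conv.weight is (c_o, c_i, 1, 1), fc.weight is (c_o, c_i)
with torch.no_grad():
    fc.weight.copy_(conv.weight.reshape(c_o, c_i))

Y_conv = conv(X)  # (1, c_o, h, w)
# Each pixel becomes a row of c_i features: (h*w, c_i)
X_flat = X.reshape(c_i, h * w).T
# Linear layer gives (h*w, c_o); transpose back and restore the spatial layout
Y_fc = fc(X_flat).T.reshape(1, c_o, h, w)

print(torch.allclose(Y_conv, Y_fc, atol=1e-6))  # True
```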

In [1]:
import torch
from torch import nn

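def corr2d(X, K):
  """2D cross-correlation of a single-channel input X with kernel K."""
  h, w = K.shape
  # Output size: (n_h - k_h + 1) x (n_w - k_w + 1)
  Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
  for i in range(Y.shape[0]):
    for j in range(Y.shape[1]):
      # Multiply the window by the kernel element-wise, then sum
      Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
  return Y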

In [3]:
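def corr2d_multi_input(X, K):
  # Iterate over the channel dimension (dim 0) of X and K,
  # cross-correlate each channel with its kernel, and sum the results
  return sum(corr2d(x, k) for x, k in zip(X, K))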

In [17]:
X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
               [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])

corr2d_multi_input(X, K)

tensor([[ 56.,  72.],
        [104., 120.]])
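As a cross-check, PyTorch's built-in `torch.nn.functional.conv2d` also computes a cross-correlation and sums over input channels. It expects an input of shape (batch, c_i, h, w) and a weight of shape (c_o, c_i, k_h, k_w), so this sketch adds a batch dimension to X and an output-channel dimension to K (both of size 1).

```python
import torch
import torch.nn.functional as F

X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
                  [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])

# (1, 2, 3, 3) input and (1, 2, 2, 2) weight -> (1, 1, 2, 2) output
Y = F.conv2d(X.unsqueeze(0), K.unsqueeze(0))
print(Y.squeeze())
# tensor([[ 56.,  72.],
#         [104., 120.]])
```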

In [7]:
X.shape,K.shape


(torch.Size([2, 3, 3]), torch.Size([2, 2, 2]))

In [8]:
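def corr2d_multi_input_output(X, K):
  # K has shape (c_o, c_i, k_h, k_w); each k = K[o] is a multi-input kernel.
  # Compute one output channel per k and stack them along dim 0
  return torch.stack([corr2d_multi_input(X, k) for k in K], 0)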

In [18]:
K = torch.stack((K, K + 1, K + 2), 0)
K.shape,K

(torch.Size([3, 2, 2, 2]),
 tensor([[[[0., 1.],
           [2., 3.]],
 
          [[1., 2.],
           [3., 4.]]],
 
 
         [[[1., 2.],
           [3., 4.]],
 
          [[2., 3.],
           [4., 5.]]],
 
 
         [[[2., 3.],
           [4., 5.]],
 
          [[3., 4.],
           [5., 6.]]]]))

In [19]:
ret = corr2d_multi_input_output(X, K)
ret, ret.shape

(tensor([[[ 56.,  72.],
          [104., 120.]],
 
         [[ 76., 100.],
          [148., 172.]],
 
         [[ 96., 128.],
          [192., 224.]]]),
 torch.Size([3, 2, 2]))
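The stacked kernel of shape (3, 2, 2, 2) already has the (c_o, c_i, k_h, k_w) layout that `torch.nn.functional.conv2d` expects, so in this sketch it can be passed directly as the weight; only the input needs a batch dimension. The result should match `ret` above.

```python
import torch
import torch.nn.functional as F

X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
                  [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])
K = torch.stack((K, K + 1, K + 2), 0)  # (c_o=3, c_i=2, 2, 2)

# (1, 2, 3, 3) input -> (1, 3, 2, 2) output; drop the batch dimension
Y = F.conv2d(X.unsqueeze(0), K).squeeze(0)
print(Y.shape)  # torch.Size([3, 2, 2])
```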

In [15]:
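def corr2d_multi_input_output_1x1(X, K):
  c_i, h, w = X.shape
  c_o = K.shape[0]
  # Flatten the spatial dims: each of the h*w pixels becomes a column
  X = X.reshape((c_i, h * w))
  K = K.reshape((c_o, c_i))
  # Same matrix multiplication as a fully connected layer
  Y = torch.matmul(K, X)
  return Y.reshape((c_o, h, w))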

In [16]:
X = torch.normal(0, 1, (3, 3, 3))
K = torch.normal(0, 1, (2, 3, 1, 1))

Y1 = corr2d_multi_input_output_1x1(X, K)
Y2 = corr2d_multi_input_output(X, K)

assert float(torch.abs(Y1 - Y2).sum()) < 1e-6
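The reshape-and-matmul in `corr2d_multi_input_output_1x1` can also be written as a single `torch.einsum` that sums over the input-channel index. The function name `conv1x1_einsum` below is hypothetical, used only for this sketch.

```python
import torch

def conv1x1_einsum(X, K):
    # X: (c_i, h, w); K: (c_o, c_i, 1, 1)
    # Y[o, h, w] = sum over i of K[o, i] * X[i, h, w]
    return torch.einsum('oi,ihw->ohw', K[:, :, 0, 0], X)

X = torch.normal(0, 1, (3, 3, 3))
K = torch.normal(0, 1, (2, 3, 1, 1))
Y = conv1x1_einsum(X, K)

# Cross-check against the reshape + matmul formulation
Y_ref = torch.matmul(K.reshape(2, 3), X.reshape(3, 9)).reshape(2, 3, 3)
print(torch.allclose(Y, Y_ref, atol=1e-6))  # True
```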