# Multiple Input and Multiple Output Channels

## 1. Multiple Input Channels
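When the input data contains multiple channels, we need to construct a kernel with the same number of input channels as the input data. Each channel of the input is cross-correlated with the matching channel of the kernel, and the per-channel results are summed into a single two-dimensional output.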

In [1]:
import torch
from d2l import torch as d2l

def corr2d_multi_in(X, K):
    # Iterate through the 0th (channel) dimension of X and K, cross-correlate
    # each pair of channels, and add the results up
    return sum(d2l.corr2d(x, k) for x, k in zip(X, K))

In [2]:
X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
               [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])

corr2d_multi_in(X, K)

tensor([[ 56.,  72.],
        [104., 120.]])
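
The helper `d2l.corr2d` is imported rather than defined in this section. As a rough sketch of what it is assumed to compute (a plain 2D cross-correlation with stride 1 and no padding), the code below reproduces the result above without the `d2l` dependency. The names `corr2d_sketch` and `corr2d_multi_in_sketch` and the loop-based implementation are illustrative assumptions, not the library's own code.

```python
import torch

def corr2d_sketch(X, K):
    """2D cross-correlation, stride 1, no padding (illustrative stand-in for d2l.corr2d)."""
    h, w = K.shape
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Elementwise product of the kernel with the window whose top-left corner is (i, j)
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y

def corr2d_multi_in_sketch(X, K):
    # One 2D cross-correlation per input channel, summed into a single output map
    return sum(corr2d_sketch(x, k) for x, k in zip(X, K))

X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
                  [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])

Y = corr2d_multi_in_sketch(X, K)
# Top-left entry: (0*0 + 1*1 + 3*2 + 4*3) + (1*1 + 2*2 + 4*3 + 5*4) = 19 + 37 = 56
print(Y)
```

Working through the top-left entry by hand, as in the comment, shows where the 56 in the output comes from: 19 from the first channel plus 37 from the second.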

## 2. Multiple Output Channels
So far, regardless of the number of input channels, we have always ended up with a single output channel. However, it turns out to be essential to have multiple channels at each layer. In the most popular neural network architectures, the channel dimension actually increases as we go deeper into the network, typically with downsampling that trades spatial resolution for greater channel depth.

Intuitively, you could think of each channel as responding to a different set of features. For example, in an edge detector, different channels may represent edges in different directions.
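
To make the edge-detector intuition concrete, here is a small sketch with two output channels applied to a toy image. The image and the two 2×2 kernels are hand-picked assumptions for illustration only (learned kernels are rarely this clean), and the computation uses PyTorch's built-in `torch.nn.functional.conv2d`, which performs cross-correlation.

```python
import torch
import torch.nn.functional as F

# Toy image with a single vertical edge: left half bright, right half dark
image = torch.tensor([[1.0, 1.0, 0.0, 0.0]] * 4)

# Two hand-picked 2x2 kernels (assumed for illustration), one per output channel
vertical = torch.tensor([[1.0, -1.0], [1.0, -1.0]])    # responds to left-right changes
horizontal = torch.tensor([[1.0, 1.0], [-1.0, -1.0]])  # responds to top-bottom changes

# F.conv2d expects input (N, C_in, H, W) and weight (C_out, C_in, kH, kW)
weight = torch.stack((vertical, horizontal)).unsqueeze(1)   # shape (2, 1, 2, 2)
response = F.conv2d(image.reshape(1, 1, 4, 4), weight)[0]   # shape (2, 3, 3)

print(response[0])  # nonzero only in the middle column, where the vertical edge sits
print(response[1])  # all zeros: the image has no horizontal edge
```

The first channel fires along the vertical edge while the second stays silent, so the two channels carry different information about the same input.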

In [4]:
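def corr2d_multi_in_out(X, K):
    # Iterate through the 0th dimension of K (one kernel per output channel),
    # cross-correlate each kernel with X, then stack the results along dimension 0
    return torch.stack([corr2d_multi_in(X, k) for k in K], 0)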

In [5]:
# Build a kernel with 3 output channels by stacking K, K + 1 and K + 2
K = torch.stack((K, K + 1, K + 2), 0)
K.shape

torch.Size([3, 2, 2, 2])

In [6]:
corr2d_multi_in_out(X, K)

tensor([[[ 56.,  72.],
         [104., 120.]],

        [[ 76., 100.],
         [148., 172.]],

        [[ 96., 128.],
         [192., 224.]]])
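
As a sanity check, the same numbers can be reproduced with PyTorch's built-in `torch.nn.functional.conv2d`, which (despite its name) computes cross-correlation. The sketch below re-creates `X` and `K` so that it runs on its own; the batch dimension added with `unsqueeze(0)` is there to match the `(N, C_in, H, W)` layout that `conv2d` works with.

```python
import torch
import torch.nn.functional as F

X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
                  [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])
K = torch.stack((K, K + 1, K + 2), 0)   # (c_o, c_i, k_h, k_w) = (3, 2, 2, 2)

# Add a batch dimension of size 1, run the built-in op, then drop the batch dimension
Y = F.conv2d(X.unsqueeze(0), K)[0]      # shape (3, 2, 2)
print(Y)
```

The first output channel matches the single-output result from Section 1, since the first kernel in the stack is the original `K`.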

## 3. $1 \times 1$ Convolutional Layer

Because the window of a 1 $\times$ 1 convolution kernel covers only a single pixel, such kernels lose the ability to recognize patterns consisting of interactions among adjacent elements in the height and width dimensions. The only computation occurs on the channel dimension.

Note that convolutional layers are typically followed by nonlinearities. This ensures that 1 $\times$ 1 convolutions cannot simply be folded into other convolutions.
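
The role of the nonlinearity can be illustrated with a small experiment. In the sketch below, a 1 $\times$ 1 convolution is represented by a `(c_o, c_i)` weight matrix; the helper `conv1x1`, the channel sizes, and the random seed are assumptions chosen for illustration. Two stacked 1 $\times$ 1 layers with nothing in between should collapse into a single layer with weights `W2 @ W1`, while inserting a ReLU should break that equivalence.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

def conv1x1(X, W):
    """Apply a 1x1 convolution given as a (c_o, c_i) matrix W to X of shape (c_i, h, w)."""
    c_i, h, w = X.shape
    return (W @ X.reshape(c_i, h * w)).reshape(W.shape[0], h, w)

X = torch.randn(3, 4, 4)
W1 = torch.randn(5, 3)   # first layer: 3 -> 5 channels
W2 = torch.randn(2, 5)   # second layer: 5 -> 2 channels

# Without a nonlinearity, the two layers equal one 1x1 layer with weights W2 @ W1
stacked_linear = conv1x1(conv1x1(X, W1), W2)
folded = conv1x1(X, W2 @ W1)
print(torch.allclose(stacked_linear, folded, atol=1e-4))

# With a ReLU in between, the same folded layer no longer reproduces the output
stacked_relu = conv1x1(torch.relu(conv1x1(X, W1)), W2)
print(torch.allclose(stacked_relu, folded, atol=1e-4))
```

The first comparison holds up to floating-point rounding because matrix multiplication is associative; the second fails because ReLU zeroes out negative activations before the second layer sees them.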

In [7]:
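def corr2d_multi_in_out_1x1(X, K):
    c_i, h, w = X.shape
    c_o = K.shape[0]
    # Flatten the spatial dimensions so that each column holds one pixel's channels
    X = X.reshape((c_i, h * w))
    # A 1x1 kernel is just a (c_o, c_i) weight matrix
    K = K.reshape((c_o, c_i))
    # Matrix multiplication acts like a fully connected layer applied at every pixel
    Y = torch.matmul(K, X)
    return Y.reshape((c_o, h, w))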

In [9]:
X = torch.normal(0, 1, (3, 3, 3))
K = torch.normal(0, 1, (2, 3, 1, 1))

Y1 = corr2d_multi_in_out_1x1(X, K)
Y2 = corr2d_multi_in_out(X, K)
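# Both implementations should agree up to float32 rounding error
assert torch.allclose(Y1, Y2, atol=1e-5), "1x1 matmul result differs from corr2d_multi_in_out"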
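
Another way to see that a 1 $\times$ 1 convolution only mixes channels is to write it as an `einsum` over the channel index and then inspect a single pixel. This is a sketch under the same shapes as above; the seed and the chosen pixel `(i, j)` are arbitrary assumptions.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
X = torch.normal(0, 1, (3, 3, 3))       # (c_i, h, w)
K = torch.normal(0, 1, (2, 3, 1, 1))    # (c_o, c_i, 1, 1)

W = K[:, :, 0, 0]                       # (c_o, c_i) weight matrix shared by all pixels
Y = torch.einsum('oi,ihw->ohw', W, X)   # sum over input channels at each (h, w)

# Each output pixel depends only on the input channels at the same location
i, j = 1, 2
pixel = W @ X[:, i, j]
print(torch.allclose(Y[:, i, j], pixel, atol=1e-5))
```

Since every pixel is transformed by the same matrix `W`, the layer behaves like a fully connected layer over channels whose weights are shared across all spatial positions.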