## 7.4　Convolution／Pooling レイヤの実装

In [9]:
import os
import sys
import numpy as np

In [10]:
# Chapter directory of the "deep-learning-from-scratch" checkout. Set the
# DLFS_CH_DIR environment variable to point elsewhere instead of editing
# this machine-specific absolute path.
BOOK_CH_DIR = os.environ.get(
    'DLFS_CH_DIR',
    '/Users/yuta.shimizu/Downloads/ML/deep-learning-from-scratch-master/ch01',
)
os.chdir(BOOK_CH_DIR)

# Make the parent directory importable (it provides `common.util`).
# An absolute path keeps the import working if the cwd changes later, and the
# membership check avoids piling up duplicate entries when the cell is re-run.
_book_root = os.path.abspath(os.pardir)
if _book_root not in sys.path:
    sys.path.append(_book_root)

In [11]:
from common.util import im2col

### 7.4.3　Convolution レイヤの実装

In [12]:
# im2col on 3-channel 7x7 inputs with a 5x5 filter (stride 1, no padding):
# a batch of 1 image and a batch of 10 images.
# Printed shapes: (9, 75) for the single image, (90, 75) for the batch of ten.
x1, x2 = np.random.rand(1, 3, 7, 7), np.random.rand(10, 3, 7, 7)
col1, col2 = (im2col(batch, 5, 5, stride=1, pad=0) for batch in (x1, x2))
print(col1.shape, col2.shape, sep='\n')

(9, 75)
(90, 75)


In [17]:
class Convolution:
    """Convolution layer (forward pass only), computed via im2col + one matmul.

    Args:
        W: Filter weights of shape (FN, C, FH, FW).
        b: Bias added to each of the FN output channels (e.g. shape (FN,)).
        stride: Step between neighbouring filter applications (default 1).
        pad: Padding passed to ``im2col`` (default 0).
    """

    def __init__(self, W, b, stride=1, pad=0):
        self.W = W
        self.b = b
        self.stride = stride
        self.pad = pad

    def forward(self, x):
        """Apply the filters to a batch ``x`` of shape (N, C, H, W).

        Returns:
            Array of shape (N, FN, out_h, out_w), where
            out_h = (H + 2*pad - FH) // stride + 1 (out_w likewise).

        Raises:
            ValueError: if ``x`` has a different channel count than ``W``.
        """
        FN, C, FH, FW = self.W.shape
        N, x_channels, H, W = x.shape
        if x_channels != C:
            raise ValueError(
                f"input has {x_channels} channels but the filters expect {C}")
        out_h = (H + 2 * self.pad - FH) // self.stride + 1
        out_w = (W + 2 * self.pad - FW) // self.stride + 1

        # Input windows become rows of `col`; each filter is flattened into a
        # column of `col_W`, so one matrix product yields every output value.
        col = im2col(x, FH, FW, self.stride, self.pad)
        col_W = self.W.reshape(FN, -1).T
        out = np.dot(col, col_W) + self.b

        # Rows are ordered (N, out_h, out_w); move the filter axis to position 1.
        out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)
        return out


# Backward-compatible alias for code that called the former module-level
# `forward(conv, x)` function.
forward = Convolution.forward

`reshape` で `-1` を指定すると、配列の総要素数が変わらないように、その軸の要素数を NumPy が自動で計算してくれる（例：形状 (10, 3, 5, 5) の配列を `reshape(10, -1)` すると (10, 75) になる）。

In [22]:
x = np.random.rand(10, 3, 5, 5)
print(x.shape)
# Keep the batch axis and let -1 infer the rest: 3 * 5 * 5 = 75.
x = np.reshape(x, (len(x), -1))
print(x.shape)

(10, 3, 5, 5)
(10, 75)


`transpose` は軸の順番を入れ替える。引数には、新しい各軸に対応する元の軸の番号を並べて指定する（例：`transpose(3, 1, 2, 0)` で形状 (10, 3, 5, 5) が (5, 3, 5, 10) になる）。

In [25]:
axis_order = (3, 1, 2, 0)  # new axis k is taken from old axis axis_order[k]
x = np.random.rand(10, 3, 5, 5)
print(x.shape)
x = np.transpose(x, axis_order)
print(x.shape)

(10, 3, 5, 5)
(5, 3, 5, 10)


### 7.4.4　Pooling レイヤの実装

In [27]:
class Pooling:
    """Max-pooling layer (forward pass only), computed via im2col.

    Args:
        pool_h: Height of the pooling window.
        pool_w: Width of the pooling window.
        stride: Step between neighbouring windows (default 2).
        pad: Padding passed to ``im2col`` (default 0).
    """

    def __init__(self, pool_h, pool_w, stride=2, pad=0):
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad

    def forward(self, x):
        """Return the maximum of each pooling window of ``x`` (shape (N, C, H, W)).

        Returns:
            Array of shape (N, C, out_h, out_w), where
            out_h = (H + 2*pad - pool_h) // stride + 1 (out_w likewise).
        """
        N, C, H, W = x.shape
        # The pad term must be included because the same pad is handed to
        # im2col below (as in Convolution.forward); without it the final
        # reshape fails whenever pad > 0.
        out_h = (H + 2 * self.pad - self.pool_h) // self.stride + 1
        out_w = (W + 2 * self.pad - self.pool_w) // self.stride + 1

        # NOTE(review): assumes im2col returns (N*out_h*out_w, C*pool_h*pool_w)
        # with each row laid out channel by channel; this reshape then gives
        # one row per (image, position, channel) window. If im2col pads with
        # zeros, padded cells also take part in the max.
        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        col = col.reshape(-1, self.pool_h * self.pool_w)

        out = np.max(col, axis=1)
        # Rows come out ordered (N, out_h, out_w, C); move channels to axis 1.
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)
        return out