# IIRフィルタ的なことを試してみるテスト

各ステージの畳み込み結果を MaxPooling で縮小し、次のフレーム(次のステージ)の入力に結合して使う実験です。

In [None]:
import os
import shutil
import numpy as np
from tqdm import tqdm
from tqdm.notebook import tqdm

import torch
import torchvision
import torchvision.transforms as transforms

import binarybrain as bb

## データセット

データセットの準備には torchvision を使います

In [None]:
# ---- configuration ----
net_name = 'MnistIirFilterTestt'
# Directory handed to bb.save_networks / bb.load_networks.
data_path = os.path.join('./data/', net_name)

bin_mode = False
# Not referenced by any active code visible in this file (only by the
# disabled RealToBinary / BinaryToReal lines inside MyNetwork).
frame_modulation_size = 7
# NOTE: `epochs` is assigned again right before the training loop.
epochs = 4
mini_batch_size = 64

# ---- dataset (MNIST via torchvision, downloaded on first use) ----
dataset_path = './data/'
dataset_train = torchvision.datasets.MNIST(
    root=dataset_path,
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
dataset_test = torchvision.datasets.MNIST(
    root=dataset_path,
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Training batches are shuffled; test batches keep the dataset order.
loader_train = torch.utils.data.DataLoader(
    dataset=dataset_train,
    batch_size=mini_batch_size,
    shuffle=True,
    num_workers=2,
)
loader_test = torch.utils.data.DataLoader(
    dataset=dataset_test,
    batch_size=mini_batch_size,
    shuffle=False,
    num_workers=2,
)

## ネットワーク構築

In [None]:
# In binary mode the BIT type can be used to reduce memory usage;
# otherwise buffers stay in FP32.
if bin_mode:
    bin_dtype = bb.DType.BIT
else:
    bin_dtype = bb.DType.FP32

def create_cnv(output_ch, filter_size=(3, 3), padding='same', fw_dtype=bin_dtype):
    """Build a convolution layer whose inner network is DenseAffine -> ReLU.

    Args:
        output_ch: Channel count; passed to ``bb.DenseAffine`` as
            ``[output_ch, 1, 1]``.
        filter_size: Convolution window, forwarded to ``bb.Convolution2d``.
        padding: Padding mode, forwarded to ``bb.Convolution2d``.
        fw_dtype: Forward data type, forwarded to ``bb.Convolution2d``.
            The default is the value ``bin_dtype`` had when this function
            was defined.

    Returns:
        The configured ``bb.Convolution2d`` instance.
    """
    inner_net = bb.Sequential([
        bb.DenseAffine([output_ch, 1, 1]),
        bb.ReLU(),
    ])
    return bb.Convolution2d(
        inner_net,
        filter_size=filter_size,
        padding=padding,
        fw_dtype=fw_dtype,
    )

def create_fc(output_ch, fw_dtype=bin_dtype):
    """Build a 1x1 convolution wrapping a single DenseAffine (no activation).

    Args:
        output_ch: Channel count; passed to ``bb.DenseAffine`` as
            ``[output_ch, 1, 1]``.
        fw_dtype: Forward data type, forwarded to ``bb.Convolution2d``.
            The default is the value ``bin_dtype`` had when this function
            was defined.

    Returns:
        The configured ``bb.Convolution2d`` instance with a (1, 1) filter.
    """
    inner_net = bb.Sequential([
        bb.DenseAffine([output_ch, 1, 1]),
    ])
    return bb.Convolution2d(inner_net, filter_size=(1, 1), fw_dtype=fw_dtype)

class MyNetwork(bb.Sequential):
    """Feedback ("IIR-like") CNN unrolled into ``N`` stages.

    Every stage concatenates the input image with the up-sampled feedback
    from the previous stage, applies two ``create_cnv(64)`` layers, and
    produces an output through ``create_fc(10)``.  The second convolution's
    result is max-pooled and becomes the feedback for the next stage.

    Each stage owns its own layer objects.  Weight sharing is emulated by
    ``param_copy`` (stage 0 -> stages 1..N-1) and ``grad_marge`` (gradients
    of stages 1..N-1 are added into stage 0).

    The feedback buffer shape is hard-coded as (frames, 64, 14, 14), and
    ``backward`` drops exactly one leading channel, so the code as written
    assumes a 1-channel 28x28 input image.

    NOTE(review): ``grad_marge`` only adds into stage 0's gradients; confirm
    that the gradients of stages 1..N-1 are reset between iterations.
    """

    def __init__(self):
        """Create ``N`` copies of every layer and register them as children."""
        self.N = 4
        # Disabled binary modulation layers, kept for reference:
#       self.r2b = bb.RealToBinary(frame_modulation_size=frame_modulation_size, bin_dtype=bin_dtype)
#       self.b2r = bb.BinaryToReal(frame_integration_size=frame_modulation_size, bin_dtype=bin_dtype)

        self.cnvs0 = bb.Sequential()  # first convolution of each stage
        self.cnvs1 = bb.Sequential()  # second convolution of each stage
        self.fcs = bb.Sequential()    # 1x1 output layer of each stage
        self.upss = bb.Sequential()   # up-sampling of the incoming feedback
        self.pols = bb.Sequential()   # max-pooling that produces the feedback
        for _ in range(self.N):
            self.cnvs0.append(create_cnv(64))
            self.cnvs1.append(create_cnv(64))
            self.fcs.append(create_fc(10))
            self.upss.append(bb.UpSampling((2, 2)))
            self.pols.append(bb.MaxPooling((2, 2)))
        super().__init__([self.cnvs0, self.cnvs1, self.fcs, self.upss, self.pols])

    def set_input_shape(self, shape):
        """Propagate ``shape`` through the layers of every stage.

        ``shape`` is the shape of the concatenated (image + feedback) input
        of a stage.  Nothing is returned.
        """
        for stage in range(self.N):
            after_cnv0 = self.cnvs0[stage].set_input_shape(shape)
            after_cnv1 = self.cnvs1[stage].set_input_shape(after_cnv0)
            self.fcs[stage].set_input_shape(after_cnv1)
            pooled = self.pols[stage].set_input_shape(after_cnv1)
            self.upss[stage].set_input_shape(pooled)

    def param_copy(self):
        """Overwrite the W/b of stages 1..N-1 in place with stage 0's values."""
        for stage in range(1, self.N):
            for group in (self.cnvs0, self.cnvs1, self.fcs):
                # [1][0] selects the DenseAffine inside the Convolution2d wrapper.
                source = group[0][1][0]
                target = group[stage][1][0]
                weight = target.W()
                bias = target.b()
                # Zero, then add: writes through the tensor handle in place.
                weight *= 0
                weight += source.W()
                bias *= 0
                bias += source.b()

    def grad_marge(self):
        """Add the dW/db of stages 1..N-1 into stage 0's gradients, in place."""
        for group in (self.cnvs0, self.cnvs1, self.fcs):
            total_dw = group[0][1][0].dW()
            total_db = group[0][1][0].db()
            for stage in range(1, self.N):
                total_dw += group[stage][1][0].dW()
                total_db += group[stage][1][0].db()

    def forward(self, x_buf, train=True):
        """Run all stages and return the list of per-stage output buffers.

        Args:
            x_buf: Input frame buffer holding the image.
            train: Forwarded to every layer's ``forward``.

        Returns:
            A list with ``N`` output buffers, one per stage, in stage order.
        """
        image = x_buf.numpy()
        # The first stage receives an all-zero feedback map.
        feedback = bb.FrameBuffer.from_numpy(
            np.zeros((x_buf.get_frame_size(), 64, 14, 14), dtype=np.float32))

        outputs = []
        for stage in range(self.N):
            # Join the image with the up-sampled output of the previous stage.
            feedback = self.upss[stage].forward(feedback, train=train)
            stage_input = bb.FrameBuffer.from_numpy(
                np.concatenate((image, feedback.numpy()), 1))

            hidden = self.cnvs0[stage].forward(stage_input, train=train)
            hidden = self.cnvs1[stage].forward(hidden, train=train)

            # Every stage contributes one output.
            outputs.append(self.fcs[stage].forward(hidden, train=train))

            # Shrink the features for use by the next stage.
            feedback = self.pols[stage].forward(hidden, train=train)

        return outputs

    def backward(self, dy_bufs):
        """Back-propagate the per-stage output gradients, last stage first.

        Args:
            dy_bufs: One gradient buffer per stage, in the order returned by
                ``forward``.

        Nothing is returned; the gradient w.r.t. the image is discarded.
        """
        # Nothing consumes the last stage's pooled output, so its gradient is zero.
        feedback_grad = bb.FrameBuffer.from_numpy(
            np.zeros((dy_bufs[0].get_frame_size(), 64, 14, 14), dtype=np.float32))
        for stage in reversed(range(self.N)):
            feedback_grad = self.pols[stage].backward(feedback_grad)
            head_grad = self.fcs[stage].backward(dy_bufs[stage])
            # cnvs1's output fed both the output layer and the pooling layer.
            grad = self.cnvs1[stage].backward(head_grad + feedback_grad)
            grad = self.cnvs0[stage].backward(grad)
            # Drop channel 0 (the image); the rest belongs to the feedback.
            feedback_part = grad.numpy()[:, 1:]
            feedback_grad = self.upss[stage].backward(
                bb.FrameBuffer.from_numpy(feedback_part))
        
    
# Instantiate the network and set the per-stage input shape:
# 64 feedback channels + 1 image channel, 28x28 pixels.
net = MyNetwork()
net.set_input_shape([64+1, 28, 28])
# Make stages 1..N-1 start from stage 0's parameters.
net.param_copy()
# Adds the gradients of stages 1..N-1 into stage 0's gradients.
net.grad_marge()

In [None]:
#bb.load_networks(data_path, net)

In [None]:
# One loss accumulator per unrolled stage of the network.
losses = [bb.LossSoftmaxCrossEntropy() for _ in range(net.N)]

optimizer = bb.OptimizerAdam()
metrics   = bb.MetricsCategoricalAccuracy()

# Only stage 0 is handed to the optimizer.  The other stages are overwritten
# from stage 0 by net.param_copy() and contribute through net.grad_marge().
parameters = bb.Variables()
gradients = bb.Variables()
for layer in (net.cnvs0[0], net.cnvs1[0], net.fcs[0]):
    parameters.append(layer.get_parameters())
    gradients.append(layer.get_gradients())
optimizer.set_variables(parameters, gradients)

# NOTE: this replaces the `epochs` value assigned in the configuration cell.
epochs = 32
for epoch in range(epochs):
    # Reset the running statistics so the progress bar shows values for the
    # current epoch only instead of an average over all epochs so far.
    for loss_fn in losses:
        loss_fn.clear()
    metrics.clear()

    with tqdm(loader_train) as tq:
        for images, labels in tq:
            x_buf = bb.FrameBuffer.from_numpy(np.array(images).astype(np.float32))

            # Per-pixel one-hot target: only the central 2x2 pixels of the
            # label's channel are set, so only that region is evaluated.
            t = np.zeros((len(labels), 10, 28, 28), dtype=np.float32)
            t[np.arange(len(labels)), np.asarray(labels), 13:15, 13:15] = 1
            t_buf = bb.FrameBuffer.from_numpy(t)

            # Sync every stage with the freshly updated stage-0 parameters.
            net.param_copy()
            y_bufs = net.forward(x_buf, train=True)

            # Every stage's output is trained against the same target.
            dy_bufs = [loss_fn.calculate(y_buf, t_buf)
                       for loss_fn, y_buf in zip(losses, y_bufs)]

            # Accuracy is measured on the last stage's output only.
            metrics.calculate(y_bufs[-1], t_buf)

            net.backward(dy_bufs)
            net.grad_marge()

            optimizer.update()

            # Displayed loss is the sum over all stages.
            loss = sum(loss_fn.get() for loss_fn in losses)

            tq.set_postfix(loss=loss, metrics=metrics.get())
    # Snapshot the network after every epoch.
    bb.save_networks(data_path, net)

In [None]:
---------------------

In [None]:
# Manually save the current state of the network under data_path.
bb.save_networks(data_path, net)