# 微分可能LUTモデルによるMobileNet風のCIFAR-10学習

pointwise-depthwise-pointwise の組み合わせによる MobileNet 風のネットワークで CIFAR-10 を学習させます。

In [1]:
import numpy as np
from tqdm.notebook import tqdm

import torch
import torchvision
import torchvision.transforms as transforms

import binarybrain as bb

### データセット

データセットの準備には torchvision を使います

In [2]:
# setting
bin_mode  = True   # binary mode: selects the BIT dtype and sends "binary true" to the net
epoch     = 8      # number of training epochs
# The dataset is CIFAR-10, so name the network accordingly; the name is used for
# the checkpoint directory and for the exported Verilog file/module names.
net_name  = 'Cifar10DifferentiableLutMobileNet'
# Join with a separator: plain concatenation produced './dataMnist...'.
data_path = './data/' + net_name
frame_modulation_size = 1  # passed to RealToBinary / BinaryToReal
batch_size = 64

# dataset
dataset_train = torchvision.datasets.CIFAR10(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
dataset_test  = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor(), download=True)
loader_train = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True, num_workers=2)
loader_test  = torch.utils.data.DataLoader(dataset=dataset_test,  batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


## ネットワーク構築

Convolution2d を使って畳み込み層を作ります。<br>
Convolution2d は指定した層を im2col と col2im で挟み込んで Lowering による畳み込みをサポートします。<br>
DenseAffine を Lowering すると一般にCNNで知られる畳み込み層になりますが、LUT-Network では
ここに DifferentiableLut を組み合わせて作った層を設定することでDenseAffineとは異なる効率の良い畳み込み層を実現します。

In [3]:
# In binary mode the BIT dtype can be used to reduce memory; otherwise use FP32.
if bin_mode:
    bin_dtype = bb.DType.BIT
else:
    bin_dtype = bb.DType.FP32


def make_conv_layer(hidden_ch, output_ch, padding='same', bin_dtype=bb.DType.BIT):
    """Build a MobileNet-style block: pointwise 1x1 -> depthwise 3x3 -> pointwise 1x1.

    Each stage is a Convolution2d that lowers a two-layer DifferentiableLut stack.

    Args:
        hidden_ch: channel count of the input pointwise stage and the depthwise stage.
        output_ch: channel count of the output pointwise stage.
        padding: padding mode of the 3x3 depthwise stage (e.g. 'same' or 'valid').
        bin_dtype: dtype passed to every LUT layer and Convolution2d.

    Returns:
        A bb.Sequential containing the three convolution stages, in order.
    """
    def pointwise(channels):
        # 1x1 convolution: LUT shapes [channels*6,1,1] then [channels,1,1] ('serial').
        lut_stack = bb.Sequential([
            bb.DifferentiableLut([channels*6, 1, 1], bin_dtype=bin_dtype),
            bb.DifferentiableLut([channels,   1, 1], connection='serial', bin_dtype=bin_dtype),
        ])
        return bb.Convolution2d(lut_stack, filter_size=(1, 1), fw_dtype=bin_dtype)

    def depthwise(channels):
        # 3x3 convolution whose LUT layers use 'depthwise' connection.
        lut_stack = bb.Sequential([
            bb.DifferentiableLut([channels, 1, 6], connection='depthwise', bin_dtype=bin_dtype),
            bb.DifferentiableLut([channels, 1, 1], connection='depthwise', bin_dtype=bin_dtype),
        ])
        return bb.Convolution2d(lut_stack, filter_size=(3, 3), padding=padding, fw_dtype=bin_dtype)

    # Built left to right, i.e. in the same order as the stages appear.
    stages = [pointwise(hidden_ch), depthwise(hidden_ch), pointwise(output_ch)]
    return bb.Sequential(stages)

# define network
# Layout: RealToBinary, three feature stages, BinaryToReal.
# Each feature stage is its own bb.Sequential, so it can be exported to Verilog
# separately as net[1], net[2], net[3].
# Spatial sizes in the comments assume the [3, 32, 32] input shape set below.
net = bb.Sequential([
    bb.RealToBinary(frame_modulation_size=frame_modulation_size, depth_modulation_size=15, bin_dtype=bin_dtype),

    # stage 0 (net[1])
    bb.Sequential([
        # Alternative first layer: make_conv_layer(72, 36, padding='same', ...) keeps 32x32.
        make_conv_layer(72, 36, padding='valid', bin_dtype=bin_dtype),    # 32x32 -> 30x30
        make_conv_layer(72, 72, padding='valid', bin_dtype=bin_dtype),    # 30x30 -> 28x28
        bb.MaxPooling(filter_size=(2, 2), fw_dtype=bin_dtype),            # 28x28 -> 14x14
    ]),

    # stage 1 (net[2])
    bb.Sequential([
        # Alternative first layer: make_conv_layer(144, 72, padding='same', ...) keeps 14x14.
        make_conv_layer(144, 72, padding='valid', bin_dtype=bin_dtype),   # 14x14 -> 12x12
        make_conv_layer(144, 144, padding='valid', bin_dtype=bin_dtype),  # 12x12 -> 10x10
        bb.MaxPooling(filter_size=(2, 2), fw_dtype=bin_dtype),            # 10x10 -> 5x5
    ]),

    # stage 2 (net[3]): 5x5 convolution down to 1x1 with 10 outputs
    bb.Sequential([
        bb.Convolution2d(
            bb.Sequential([
                bb.DifferentiableLut([512*6],  bin_dtype=bin_dtype),
                bb.DifferentiableLut([512],    bin_dtype=bin_dtype),
                bb.DifferentiableLut([10*6*6], bin_dtype=bin_dtype),
                bb.DifferentiableLut([10*6],   bin_dtype=bin_dtype),
                bb.DifferentiableLut([10],     bin_dtype=bin_dtype),
            ]),
            filter_size=(5, 5),
            fw_dtype=bin_dtype),
    ]),

    bb.BinaryToReal(frame_modulation_size=frame_modulation_size, bin_dtype=bin_dtype),
])

net.set_input_shape([3, 32, 32])

if bin_mode:
    net.send_command("binary true")

# To display the network structure: print(net.get_info())

## 学習実施

学習を行います

In [4]:
# To resume from a previously saved checkpoint, uncomment:
# bb.load_networks(data_path, net)

# learning
loss      = bb.LossSoftmaxCrossEntropy()
metrics   = bb.MetricsCategoricalAccuracy()
optimizer = bb.OptimizerAdam()

optimizer.set_variables(net.get_parameters(), net.get_gradients())

# One-hot lookup table for the 10 classes. It is built once instead of per batch.
# Fancy indexing below returns a fresh float32 copy.
_one_hot_table = np.identity(10, dtype=np.float32)


def _to_frame_buffers(images, labels):
    """Convert a DataLoader batch into (input, one-hot target) FrameBuffers."""
    x_buf = bb.FrameBuffer.from_numpy(np.array(images).astype(np.float32))
    t_buf = bb.FrameBuffer.from_numpy(_one_hot_table[np.array(labels)])
    return x_buf, t_buf


# Use a distinct loop variable. Reusing `epoch` overwrote the epoch count,
# so re-running this cell trained for fewer epochs.
for epoch_index in range(epoch):
    # training pass
    loss.clear()
    metrics.clear()
    with tqdm(loader_train) as t:
        for images, labels in t:
            x_buf, t_buf = _to_frame_buffers(images, labels)

            y_buf = net.forward(x_buf, train=True)

            dy_buf = loss.calculate(y_buf, t_buf)
            metrics.calculate(y_buf, t_buf)
            net.backward(dy_buf)

            optimizer.update()

            t.set_postfix(loss=loss.get(), acc=metrics.get())

    # test pass (loss/metrics are reset so they reflect the test set only)
    loss.clear()
    metrics.clear()
    for images, labels in loader_test:
        x_buf, t_buf = _to_frame_buffers(images, labels)

        y_buf = net.forward(x_buf, train=False)

        loss.calculate(y_buf, t_buf)
        metrics.calculate(y_buf, t_buf)

    # checkpoint after every epoch
    bb.save_networks(data_path, net)

    print('epoch[%d] : loss=%f accuracy=%f' % (epoch_index, loss.get(), metrics.get()))

HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch[0] : loss=2.063036 accuracy=0.193300


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch[1] : loss=1.971270 accuracy=0.222400


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch[2] : loss=1.955115 accuracy=0.223700


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch[3] : loss=2.014500 accuracy=0.235600


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch[4] : loss=1.916100 accuracy=0.251400


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch[5] : loss=1.929415 accuracy=0.275800


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch[6] : loss=1.936344 accuracy=0.215700


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))




KeyboardInterrupt: 

## RTL(Verilog)変換

FPGA化するために Verilog に変換します。インターフェースはXilinx社のAXI4 Stream Video 仕様(フレームスタートでtuserが立つ)となります。
MaxPooling で画像サイズが縮小されるため、現状はその単位ごとにしか変換できません。そこでネットワークを3つに分けて出力しています。

In [None]:
# export verilog
# Write the three feature stages (net[1], net[2], net[3]) into one Verilog file.
# The modules are named <net_name>Cnv0, <net_name>Cnv1 and <net_name>Cnv2.
verilog_path = net_name + '.v'
with open(verilog_path, 'w') as verilog_file:
    verilog_file.write('`timescale 1ns / 1ps\n\n')
    for stage_no in range(3):
        module_name = f'{net_name}Cnv{stage_no}'
        verilog_file.write(bb.make_verilog_lut_cnv_layers(module_name, net[stage_no + 1]))
    