# 微分可能LUTモデルによるCNNでのMNIST学習

Differentiable LUT モデルで畳み込み層を構成し、一般的なデータに対して CNN による回路学習を行います。

In [1]:
import os
import shutil
import numpy as np
from tqdm import tqdm
#from tqdm.notebook import tqdm

import torch
import torchvision
import torchvision.transforms as transforms

import binarybrain as bb

### データセット

データセットの準備には torchvision を使います。

In [2]:
# ---- configuration ---------------------------------------------------------
# Names and paths shared by the training and Verilog-export cells.
net_name = 'MnistDifferentiableLutCnn'
# Checkpoints (load/save_networks) and the generated Verilog are written here.
data_path = os.path.join('./data/', net_name)
# Directory of the RTL simulation testbench that receives the Verilog and test images.
rtl_sim_path = '../../verilog/mnist/tb_mnist_lut_cnn'
rtl_module_name = 'MnistLutCnn'
# NOTE: "velilog" is a historical misspelling; the names are kept because
# other cells refer to them.
output_velilog_file = os.path.join(data_path, rtl_module_name + '.v')
sim_velilog_file = os.path.join(rtl_sim_path, rtl_module_name + '.v')

# True selects the BIT dtype and sends "binary true" to the network.
bin_mode = True
# Frame count passed to RealToBinary (modulation) and BinaryToReal (integration).
frame_modulation_size = 15
epochs = 64
mini_batch_size = 64


# ---- dataset ---------------------------------------------------------------
# MNIST from torchvision; ToTensor turns each image into a float tensor in [0, 1].
dataset_path = './data/'
dataset_train, dataset_test = (
    torchvision.datasets.MNIST(root=dataset_path, train=is_train,
                               transform=transforms.ToTensor(), download=True)
    for is_train in (True, False)
)
# Only the training loader shuffles; the test order stays fixed.
loader_train, loader_test = (
    torch.utils.data.DataLoader(dataset=ds, batch_size=mini_batch_size,
                                shuffle=do_shuffle, num_workers=2)
    for ds, do_shuffle in ((dataset_train, True), (dataset_test, False))
)

## ネットワーク構築

Convolution2d を使って畳み込み層を作ります。<br>
Convolution2d は指定した層を im2col と col2im で挟み込んで Lowering による畳み込みをサポートします。<br>
DenseAffine を Lowering すると一般にCNNで知られる畳み込み層になりますが、LUT-Network では
ここに DifferentiableLut を組み合わせて作った層を設定することでDenseAffineとは異なる効率の良い畳み込み層を実現します。

In [3]:
# In binary mode the BIT dtype can be used to reduce memory; otherwise FP32.
bin_dtype = bb.DType.BIT if bin_mode else bb.DType.FP32

# define network
# Top-level layout (the Verilog export cell indexes net[1]..net[3]):
#   net[0]  RealToBinary (uses frame_modulation_size)
#   net[1]  two 3x3 LUT convolutions + 2x2 MaxPooling
#   net[2]  two 3x3 LUT convolutions + 2x2 MaxPooling
#   net[3]  one 4x4 LUT convolution ending in AverageLut([10]) -> 10 classes
#   net[4]  BinaryToReal (uses frame_modulation_size)
# Each convolution wraps a LUT Sequential: a 'random'-connected LUT layer
# followed by 'serial' LUT layer(s). The last layer size (36, 72, 144, 10) is
# presumably the output channel count - TODO confirm.
# Assuming the convolutions use no padding, the spatial size goes
# 28 -> 26 -> 24 -> 12 -> 10 -> 8 -> 4 -> 1.
net = bb.Sequential([
            bb.RealToBinary(frame_modulation_size=frame_modulation_size, bin_dtype=bin_dtype),
            bb.Sequential([
                bb.Convolution2d(
                    bb.Sequential([
                        bb.DifferentiableLut([36*6], connection='random', bin_dtype=bin_dtype),
                        bb.DifferentiableLut([36], connection='serial', bin_dtype=bin_dtype),
                    ]),
                    filter_size=(3, 3),
                    fw_dtype=bin_dtype),
                bb.Convolution2d(
                    bb.Sequential([
                        bb.DifferentiableLut([2*36*6], connection='random', bin_dtype=bin_dtype),
                        bb.DifferentiableLut([2*36], connection='serial', bin_dtype=bin_dtype),
                    ]),
                    filter_size=(3, 3),
                    fw_dtype=bin_dtype),
                bb.MaxPooling(filter_size=(2, 2), fw_dtype=bin_dtype),
            ]),
            bb.Sequential([
                bb.Convolution2d(
                    bb.Sequential([
                        bb.DifferentiableLut([2*36*6], connection='random', bin_dtype=bin_dtype),
                        bb.DifferentiableLut([2*36], connection='serial', bin_dtype=bin_dtype),
                    ]),
                    filter_size=(3, 3),
                    fw_dtype=bin_dtype),
                bb.Convolution2d(
                    bb.Sequential([
                        bb.DifferentiableLut([4*36*6], connection='random', bin_dtype=bin_dtype),
                        bb.DifferentiableLut([4*36], connection='serial', bin_dtype=bin_dtype),
                    ]),
                    filter_size=(3, 3),
                    fw_dtype=bin_dtype),
                bb.MaxPooling(filter_size=(2, 2), fw_dtype=bin_dtype),
            ]),
            bb.Sequential([
                bb.Convolution2d(
                    bb.Sequential([
                        bb.DifferentiableLut([6*256], connection='random', bin_dtype=bin_dtype),
                        bb.DifferentiableLut([256], connection='serial', bin_dtype=bin_dtype),
                        
                        bb.DifferentiableLut([6*6*6*10], connection='random', bin_dtype=bin_dtype),
                        bb.DifferentiableLut([6*6*10], connection='serial', bin_dtype=bin_dtype),
                        bb.DifferentiableLut([6*10], connection='serial', bin_dtype=bin_dtype),
                        bb.AverageLut([10], connection='serial', bin_dtype=bin_dtype),
                    ]),
                    filter_size=(4, 4),
                    fw_dtype=bin_dtype),
            ]),
            bb.BinaryToReal(frame_integration_size=frame_modulation_size, bin_dtype=bin_dtype)
        ])

# MNIST input: 1 channel, 28x28 pixels.
net.set_input_shape([1, 28, 28])

# Presumably switches the layers into binary operation - TODO confirm in binarybrain docs.
if bin_mode:
    net.send_command("binary true")

# Uncomment to print a summary of the network:
#   print(net.get_info())

## 学習実施

学習を行います

In [4]:
# Load previous training results (if any) so training resumes from them.
bb.load_networks(data_path, net)

# learning
loss      = bb.LossSoftmaxCrossEntropy()
metrics   = bb.MetricsCategoricalAccuracy()
optimizer = bb.OptimizerAdam(learning_rate=0.0001)

optimizer.set_variables(net.get_parameters(), net.get_gradients())

# One-hot lookup table for the 10 MNIST classes, built once instead of per batch.
_ONE_HOT_10 = np.identity(10, dtype=np.float32)


def _to_frame_buffers(images, labels):
    """Convert a DataLoader batch into (input, one-hot target) FrameBuffers.

    images: batch of image tensors, converted to float32 as-is.
    labels: batch of integer class labels in [0, 10).
    Returns (x_buf, t_buf) ready for net.forward / loss.calculate.
    """
    x_buf = bb.FrameBuffer.from_numpy(np.array(images).astype(np.float32))
    # Fancy indexing returns a fresh contiguous float32 copy of the selected rows.
    t_buf = bb.FrameBuffer.from_numpy(_ONE_HOT_10[np.array(labels)])
    return x_buf, t_buf


for epoch in range(epochs):
    loss.clear()
    metrics.clear()

    # learning
    with tqdm(loader_train) as t:
        for images, labels in t:
            x_buf, t_buf = _to_frame_buffers(images, labels)

            y_buf = net.forward(x_buf, train=True)

            dy_buf = loss.calculate(y_buf, t_buf)
            metrics.calculate(y_buf, t_buf)
            net.backward(dy_buf)

            optimizer.update()

            t.set_postfix(loss=loss.get(), acc=metrics.get())

    # test: clear again so the values printed below come from the test pass only
    # (assuming clear() resets the running totals).
    loss.clear()
    metrics.clear()
    for images, labels in loader_test:
        x_buf, t_buf = _to_frame_buffers(images, labels)

        y_buf = net.forward(x_buf, train=False)

        loss.calculate(y_buf, t_buf)
        metrics.calculate(y_buf, t_buf)

    # Save after every epoch; load_networks above picks this up on the next run,
    # so an interrupted training session can be resumed.
    bb.save_networks(data_path, net)

    print('epoch[%d] : loss=%f accuracy=%f' % (epoch, loss.get(), metrics.get()))

100%|██████████| 938/938 [12:33<00:00,  1.25it/s, acc=0.981, loss=1.5]


epoch[0] : loss=1.490199 accuracy=0.980700


100%|██████████| 938/938 [12:53<00:00,  1.21it/s, acc=0.981, loss=1.5]


epoch[1] : loss=1.497075 accuracy=0.981900


100%|██████████| 938/938 [12:55<00:00,  1.21it/s, acc=0.981, loss=1.5]


epoch[2] : loss=1.494005 accuracy=0.980900


100%|██████████| 938/938 [12:54<00:00,  1.21it/s, acc=0.982, loss=1.5]


epoch[3] : loss=1.498245 accuracy=0.983500


100%|██████████| 938/938 [12:55<00:00,  1.21it/s, acc=0.982, loss=1.5]


epoch[4] : loss=1.497976 accuracy=0.982400


100%|██████████| 938/938 [12:55<00:00,  1.21it/s, acc=0.982, loss=1.5]


epoch[5] : loss=1.489045 accuracy=0.980500


100%|██████████| 938/938 [12:55<00:00,  1.21it/s, acc=0.982, loss=1.5]


epoch[6] : loss=1.494769 accuracy=0.983300


100%|██████████| 938/938 [12:53<00:00,  1.21it/s, acc=0.983, loss=1.5]


epoch[7] : loss=1.497969 accuracy=0.983800


100%|██████████| 938/938 [12:53<00:00,  1.21it/s, acc=0.982, loss=1.5]


epoch[8] : loss=1.494769 accuracy=0.984000


100%|██████████| 938/938 [12:52<00:00,  1.21it/s, acc=0.983, loss=1.5]


epoch[9] : loss=1.498856 accuracy=0.980000


100%|██████████| 938/938 [12:52<00:00,  1.21it/s, acc=0.982, loss=1.5]


epoch[10] : loss=1.499596 accuracy=0.978700


100%|██████████| 938/938 [12:52<00:00,  1.22it/s, acc=0.982, loss=1.5]


epoch[11] : loss=1.494882 accuracy=0.982900


100%|██████████| 938/938 [12:52<00:00,  1.21it/s, acc=0.982, loss=1.5]


epoch[12] : loss=1.492211 accuracy=0.985100


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.982, loss=1.5]


epoch[13] : loss=1.494151 accuracy=0.981900


100%|██████████| 938/938 [12:52<00:00,  1.21it/s, acc=0.983, loss=1.5]


epoch[14] : loss=1.496747 accuracy=0.984000


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[15] : loss=1.492046 accuracy=0.982900


100%|██████████| 938/938 [12:52<00:00,  1.21it/s, acc=0.983, loss=1.5]


epoch[16] : loss=1.492478 accuracy=0.981400


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[17] : loss=1.490768 accuracy=0.984600


100%|██████████| 938/938 [12:52<00:00,  1.21it/s, acc=0.983, loss=1.5]


epoch[18] : loss=1.494351 accuracy=0.983800


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[19] : loss=1.494625 accuracy=0.983600


100%|██████████| 938/938 [12:52<00:00,  1.21it/s, acc=0.984, loss=1.5]


epoch[20] : loss=1.495061 accuracy=0.981500


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[21] : loss=1.494685 accuracy=0.981600


100%|██████████| 938/938 [12:52<00:00,  1.21it/s, acc=0.983, loss=1.5]


epoch[22] : loss=1.495507 accuracy=0.981400


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[23] : loss=1.495487 accuracy=0.981100


100%|██████████| 938/938 [12:52<00:00,  1.21it/s, acc=0.983, loss=1.5]


epoch[24] : loss=1.495574 accuracy=0.984400


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[25] : loss=1.501202 accuracy=0.983400


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[26] : loss=1.496323 accuracy=0.979500


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[27] : loss=1.497509 accuracy=0.979600


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.982, loss=1.5]


epoch[28] : loss=1.488937 accuracy=0.985500


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[29] : loss=1.499992 accuracy=0.976700


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[30] : loss=1.492426 accuracy=0.980900


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[31] : loss=1.491870 accuracy=0.981800


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[32] : loss=1.499224 accuracy=0.983100


100%|██████████| 938/938 [12:50<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[33] : loss=1.506743 accuracy=0.980300


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.982, loss=1.5]


epoch[34] : loss=1.491235 accuracy=0.983100


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.984, loss=1.5]


epoch[35] : loss=1.493385 accuracy=0.983000


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.984, loss=1.5]


epoch[36] : loss=1.492534 accuracy=0.981400


100%|██████████| 938/938 [12:50<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[37] : loss=1.494963 accuracy=0.979500


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[38] : loss=1.502021 accuracy=0.981600


100%|██████████| 938/938 [12:50<00:00,  1.22it/s, acc=0.983, loss=1.5]


epoch[39] : loss=1.494050 accuracy=0.985300


100%|██████████| 938/938 [12:51<00:00,  1.22it/s, acc=0.984, loss=1.5]


epoch[40] : loss=1.491365 accuracy=0.985400


100%|██████████| 938/938 [12:50<00:00,  1.22it/s, acc=0.983, loss=1.5] 


epoch[41] : loss=1.503267 accuracy=0.976700


 74%|███████▍  | 693/938 [09:30<03:21,  1.21it/s, acc=0.983, loss=1.5]

## RTL(Verilog)変換

FPGA化するために Verilog に変換します。インターフェースはXilinx社のAXI4 Stream Video 仕様(フレームスタートでtuserが立つ)となります。
MaxPooling で画像サイズが縮小されるため、現状は MaxPooling の単位でしか変換できません。そのため、ネットワークを3つに分けて出力しています。

In [None]:
# export verilog
# Make sure the output directory exists: this cell may be run without a prior
# training run that saved into data_path.
os.makedirs(os.path.dirname(output_velilog_file), exist_ok=True)

with open(output_velilog_file, 'w') as f:
    f.write('`timescale 1ns / 1ps\n\n')
    # net[0] (RealToBinary) and net[4] (BinaryToReal) are not exported.
    # net[1]..net[3] become modules <rtl_module_name>Cnv0..Cnv2; net[1] and
    # net[2] end in MaxPooling, net[3] is the final 4x4 convolution.
    for stage in range(3):
        bb.dump_verilog_lut_cnv_layers(f, rtl_module_name + 'Cnv%d' % stage, net[stage + 1])

# Overwrite the copy used by the RTL simulation testbench.
# NOTE: rtl_sim_path is repository-relative and is expected to exist already.
shutil.copyfile(output_velilog_file, sim_velilog_file)


# Generate the images used by the simulation.
def img_generator():
    """Yield only the image from each (image, label) test sample."""
    for image, _label in dataset_test:
        yield image


# Backward-compatible alias for the previous (misspelled) name.
img_geneator = img_generator

# Enough 28x28 digit tiles to cover a 1280x720 frame; smaller frames are crops.
# make_image_tile presumably takes (rows, cols, images) - TODO confirm.
tile_rows = 720 // 28 + 1
tile_cols = 1280 // 28 + 1
img = (bb.make_image_tile(tile_rows, tile_cols, img_generator()) * 255).astype(np.uint8)
for width, height in ((160, 120), (640, 480), (1280, 720)):
    bb.write_ppm(os.path.join(rtl_sim_path, 'mnist_test_%dx%d.ppm' % (width, height)),
                 img[:, :height, :width])