# MNISTでセグメンテーションに挑戦



In [None]:
import os
import shutil
import random
import numpy as np
import matplotlib.pyplot as plt
#from tqdm.notebook import tqdm
from tqdm import tqdm

import torch
import torchvision
import torchvision.transforms as transforms

import binarybrain as bb

## 初期設定

In [None]:
# configuration
# Passes device index 0 to BinaryBrain (presumably selects GPU 0 -- confirm).
bb.set_device(0)

# Names and paths used when saving the network and exporting Verilog.
# NOTE: "velilog" is a typo of "verilog"; the identifiers are kept as-is
# because the RTL export cell at the end of the file refers to them.
net_name               = 'MnistSegmentationAndClassificationDistillation'
data_path              = os.path.join('./data/', net_name)
rtl_sim_path           = '../../verilog/mnist'
rtl_module_name        = 'MnistSegmentationAndClassification'
output_velilog_file    = os.path.join(data_path, net_name + '.v')
sim_velilog_file       = os.path.join(rtl_sim_path, rtl_module_name + '.v')

# bin_mode selects bb.DType.BIT vs bb.DType.FP32 for the network layers.
# frame_modulation_size is given to RealToBinary / BinaryToReal and
# depth_integration_size to the output heads (channel multiplier + Reduce).
# NOTE(review): `epochs` is not read by any code visible in this file;
# learning() is always called with an explicit epoch count.
bin_mode               = True
frame_modulation_size  = 3
depth_integration_size = 7
epochs                 = 0
mini_batch_size        = 8

## データセット準備

データセットを自作する
数値が中央に来るピクセル以外も学習させる必要があるため、28x28のMNIST画像をタイル状に並べて学習データを作る

In [None]:
# dataset
# Download (if needed) and load MNIST; ToTensor() makes each sample a tensor.
dataset_path = './data/'
dataset_train = torchvision.datasets.MNIST(root=dataset_path, train=True, transform=transforms.ToTensor(), download=True)
dataset_test  = torchvision.datasets.MNIST(root=dataset_path, train=False, transform=transforms.ToTensor(), download=True)

# Number of tiles to arrange vertically (rows) and horizontally (cols)
rows=3
cols=3

# Build weights from the area ratios.
# areas[0..9] accumulate the mean pixel value of the images of each label,
# areas[10] accumulates the mean of the inverted image (1.0 - img);
# the sums are then divided by the number of training samples.
areas = np.zeros((11))
for img, label in dataset_train:
    img = img.numpy()
    areas[label] += np.mean(img)
    areas[10] += np.mean(1.0-img)
areas /= len(dataset_train)

# Inverse-area weights, scaled so that the largest weight equals 1.0.
# NOTE(review): `wight` (sic, "weight") is not used by any code visible in
# this file.
wight = 1 / areas
wight /= np.max(wight)

def make_teacher_image(gen, rows, cols, margin=0):
    """Build one tiled source image and its per-class teacher masks.

    ``rows * cols`` samples are drawn from ``gen`` and placed on a grid of
    28x28 tiles.  The teacher image has 11 channels: channel ``label`` (0-9)
    receives the tile's pixels and channel 10 receives the inverted tile
    (``1.0 - img``).  All teacher channels are then binarized at 0.5.
    With probability of about one half the source image (not the teacher)
    is inverted.

    Args:
        gen: Iterator yielding ``(img, label)`` pairs.  ``img`` must be
            assignable into a 28x28 float32 slice and ``label`` must be a
            valid index into channels 0-9.
        rows: Number of tile rows.
        cols: Number of tile columns.
        margin: Number of pixels cropped from every border of the teacher
            image.  The source image is never cropped.

    Returns:
        tuple: ``(source_img, teaching_img)`` where ``source_img`` has shape
        ``(1, rows*28, cols*28)`` and ``teaching_img`` has shape
        ``(11, rows*28 - 2*margin, cols*28 - 2*margin)``.
    """
    tile = 28
    height = rows * tile
    width = cols * tile

    source_img = np.zeros((1, height, width), dtype=np.float32)
    teaching_img = np.zeros((11, height, width), dtype=np.float32)
    for row in range(rows):
        for col in range(cols):
            x = col * tile
            y = row * tile
            img, label = next(gen)
            source_img[0, y:y+tile, x:x+tile] = img
            teaching_img[label, y:y+tile, x:x+tile] = img
            teaching_img[10, y:y+tile, x:x+tile] = (1.0 - img)
    teaching_img = (teaching_img > 0.5).astype(np.float32)

    # Randomly invert the source image
    if random.random() > 0.5:
        source_img = 1.0 - source_img

    # Use explicit end indices: `margin:-margin` would be `0:-0`, i.e. an
    # empty slice, for the default margin of 0.
    cropped = teaching_img[:, margin:height - margin, margin:width - margin]
    return source_img, cropped

def transform_data(dataset, n, rows, cols, margin):
    """Generate ``n`` tiled (source, teacher) image pairs from ``dataset``.

    Samples are read from ``dataset`` in index order, wrapping around to the
    start when the end is reached, and handed to ``make_teacher_image``.

    Args:
        dataset: Indexable collection supporting ``len()`` whose items are
            the ``(img, label)`` pairs expected by ``make_teacher_image``.
        n: Number of tiled images to generate.
        rows: Tile rows per generated image.
        cols: Tile columns per generated image.
        margin: Border crop forwarded to ``make_teacher_image``.

    Returns:
        tuple: ``(source_imgs, teaching_imgs)``, two lists of length ``n``.
    """
    def endless_samples():
        # Cycle through the dataset forever, in index order.
        size = len(dataset)
        position = 0
        while True:
            yield dataset[position % size]
            position += 1

    samples = endless_samples()
    pairs = [make_teacher_image(samples, rows, cols, margin) for _ in range(n)]
    source_imgs = [source for source, _ in pairs]
    teaching_imgs = [teacher for _, teacher in pairs]
    return source_imgs, teaching_imgs

class MyDatasets(torch.utils.data.Dataset):
    """Dataset pairing pre-built source images with their teacher images.

    Args:
        source_imgs: Indexable collection of source images.
        teaching_imgs: Indexable collection of teacher images, aligned with
            ``source_imgs`` by index.
        transforms: Optional callable ``(source, teacher) -> (source, teacher)``
            applied to every pair returned by indexing.
    """

    def __init__(self, source_imgs, teaching_imgs, transforms=None):
        self.source_imgs = source_imgs
        self.teaching_imgs = teaching_imgs
        self.transforms = transforms

    def __len__(self):
        """Return the number of source images."""
        return len(self.source_imgs)

    def __getitem__(self, index):
        """Return the ``(source, teacher)`` pair at ``index``."""
        pair = (self.source_imgs[index], self.teaching_imgs[index])
        if not self.transforms:
            return pair
        transformed_source, transformed_teacher = self.transforms(*pair)
        return transformed_source, transformed_teacher


# Pre-generate the tiled training images (1000 mini-batches worth).
# The teacher images are cropped by 29 pixels per border; this equals the
# number of 'valid' 3x3 convolution stages in MyNetwork (presumably each
# stage trims one pixel per border -- confirm).
source_imgs_train, teaching_imgs_train = transform_data(dataset_train, mini_batch_size*1000, rows, cols, 29)
my_dataset_train = MyDatasets(source_imgs_train, teaching_imgs_train)

# Pre-generate the tiled test images (5 mini-batches worth).
source_imgs_test, teaching_imgs_test = transform_data(dataset_test, mini_batch_size*5, rows, cols, 29)
my_dataset_test = MyDatasets(source_imgs_test, teaching_imgs_test)

# Training batches are shuffled; test batches keep their order.
loader_train = torch.utils.data.DataLoader(dataset=my_dataset_train, batch_size=mini_batch_size, shuffle=True)
loader_test = torch.utils.data.DataLoader(dataset=my_dataset_test, batch_size=mini_batch_size, shuffle=False)

In [None]:
def plt_data(x, y):
    """Show a source image next to 11 result/teacher channels in one row.

    Args:
        x: Indexable whose element 0 is the 2-D source image to display.
        y: Indexable whose elements 0-10 are the 2-D channel images.
    """
    panel_count = 12
    channel_count = 11

    plt.figure(figsize=(16, 8))

    # Left-most panel: the source image.
    plt.subplot(1, panel_count, 1)
    plt.imshow(x[0], 'gray')

    # Remaining panels: one per channel.
    for channel in range(channel_count):
        plt.subplot(1, panel_count, channel + 2)
        plt.imshow(y[channel], 'gray')

    plt.show()

In [None]:
# Preview the first test batch: print the shapes of its first sample and
# plot up to 10 samples.
# NOTE(review): the figure created on the next line is never drawn on,
# because plt_data() opens its own figure for every sample.
plt.figure(figsize=(16,8))
for source_imgs, teaching_imgs in loader_test:
    print(source_imgs[0].shape)
    print(teaching_imgs[0].shape)
    # Indexes the batch up to mini_batch_size, so this assumes a full batch.
    for i in range(min(mini_batch_size, 10)):
        plt_data(source_imgs[i], teaching_imgs[i])
    break  # only the first batch is shown

In [None]:
def view(net, loader):
    """Run inference on the first two batches of ``loader`` and plot them.

    For each batch the two network outputs are concatenated along axis 1
    (``np.hstack``) and the first sample of the batch is plotted with
    ``plt_data``.

    Args:
        net: Network whose ``forward(x_buf, train=False)`` returns two
            buffers that provide ``numpy()``.
        loader: Iterable yielding ``(x_imgs, t_imgs)`` batches.  ``t_imgs``
            is not used.
    """
    max_batches = 2
    for shown, (x_imgs, _t_imgs) in enumerate(loader, start=1):
        # No plt.figure() here: plt_data() opens its own figure, so an
        # additional one would only stay empty.
        x_buf = bb.FrameBuffer.from_numpy(np.array(x_imgs).astype(np.float32))
        y0_buf, y1_buf = net.forward(x_buf, train=False)
        result_imgs = np.hstack((y0_buf.numpy(), y1_buf.numpy()))
        plt_data(x_imgs[0], result_imgs[0])
        if shown >= max_batches:
            break

## ネットワーク構築

In [None]:
# In binary mode, using the BIT type allows memory usage to be reduced
bin_dtype = bb.DType.BIT if bin_mode else bb.DType.FP32

class DistillationConv(bb.Switcher):
    """Convolution layer for distillation.

    A bb.Switcher holding two models registered under the keys 'dense'
    (a DenseAffine-based convolution, selected initially) and 'lut'
    (DifferentiableLut-based convolutions).

    Args:
        hidden_ch (int): number of intermediate channels on the LUT-Net side
        output_ch (int): number of output channels
        stage (int): index embedded in the names of this layer and its sub-layers
        filter_size (tuple): convolution filter size. (1, 1) builds the LUT
            model as a single pointwise convolution; any other size builds
            pointwise -> depthwise -> pointwise convolutions
        padding (str): passed to the dense convolution and to the depthwise
            LUT convolution
        input_shape: forwarded to bb.Switcher
        batch_norm (bool): whether the dense model contains a
            BatchNormalization layer; also forwarded to every DifferentiableLut
        bin_dtype (DType): binary type, either bb.DType.FP32 or bb.DType.BIT
    """
    
    def __init__(self, hidden_ch, output_ch, stage, *, filter_size=(3,3), padding='valid', input_shape=None, batch_norm=True, bin_dtype=bb.DType.FP32):
        name = 'DistillationConv_%d' % stage
        # Dense model: affine (+ optional batch normalization) + ReLU
        if batch_norm:
            self.dense_conv = bb.Convolution2d(
                   bb.Sequential([
                       bb.DenseAffine([output_ch], name='dense_affine_%d' % stage),
                       bb.BatchNormalization(name='dense_bn_%d' % stage),
                       bb.ReLU(name='dense_act_%d' % stage, bin_dtype=bin_dtype),
                   ]),
                   filter_size=filter_size,
                   padding=padding,
                   name='dense_conv_%d' % stage,
                   fw_dtype=bin_dtype)
        else:
            self.dense_conv = bb.Convolution2d(
                   bb.Sequential([
                       bb.DenseAffine([output_ch], name='dense_affine_%d' % stage),
                       bb.ReLU(name='dense_act_%d' % stage, bin_dtype=bin_dtype),
                   ]),
                   filter_size=filter_size,
                   padding=padding,
                   name='dense_conv_%d' % stage,
                   fw_dtype=bin_dtype)
            
        
        # LUT model
        if filter_size[0] == 1 and filter_size[1] == 1:
            # 1x1 filter: a single pointwise convolution with four LUT layers
            # (hidden_ch*6 -> hidden_ch -> output_ch*6 -> output_ch)
            self.lut_conv = bb.Sequential([
                    # pointwise
                    bb.Convolution2d(
                       bb.Sequential([
                           bb.DifferentiableLut([hidden_ch*6, 1, 1], connection='random', batch_norm=batch_norm, name='lut_conv0_lut0_%d' % stage, bin_dtype=bin_dtype),
                           bb.DifferentiableLut([hidden_ch,   1, 1], connection='serial', batch_norm=batch_norm, name='lut_conv0_lut1_%d' % stage, bin_dtype=bin_dtype),
                           bb.DifferentiableLut([output_ch*6, 1, 1], connection='random', batch_norm=batch_norm, name='lut_conv0_lut2_%d' % stage, bin_dtype=bin_dtype),
                           bb.DifferentiableLut([output_ch,   1, 1], connection='serial', batch_norm=batch_norm, name='lut_conv0_lut3_%d' % stage, bin_dtype=bin_dtype),
                       ]),
                       filter_size=(1, 1),
                       name='lut_conv0_%d' % stage,
                       fw_dtype=bin_dtype),
                ])
        else :
            # Larger filter: pointwise -> depthwise (filter_size) -> pointwise
            self.lut_conv = bb.Sequential([
                    # pointwise
                    bb.Convolution2d(
                       bb.Sequential([
                           bb.DifferentiableLut([hidden_ch*6, 1, 1], connection='random', batch_norm=batch_norm, name='lut_conv0_lut0_%d' % stage, bin_dtype=bin_dtype),
                           bb.DifferentiableLut([hidden_ch,   1, 1], connection='serial', batch_norm=batch_norm, name='lut_conv0_lut1_%d' % stage, bin_dtype=bin_dtype),
                       ]),
                       filter_size=(1, 1),
                       name='lut_conv0_%d' % stage,
                       fw_dtype=bin_dtype),

                    # depthwise
                    bb.Convolution2d(
                       bb.Sequential([
                           bb.DifferentiableLut([hidden_ch, 1, 1], connection='depthwise', batch_norm=batch_norm, name='lut_conv1_hidden0_%d' % stage, bin_dtype=bin_dtype),
                       ]),
                       filter_size=filter_size,
                       padding=padding,
                       name='lut_conv1_%d' % stage,
                       fw_dtype=bin_dtype),

                    # pointwise
                    bb.Convolution2d(
                       bb.Sequential([
                           bb.DifferentiableLut([output_ch*6, 1, 1], connection='random', batch_norm=batch_norm, name='lut_conv2_input0_%d' % stage, bin_dtype=bin_dtype),
                           bb.DifferentiableLut([output_ch,   1, 1], connection='serial', batch_norm=batch_norm, name='lut_conv2_input1_%d' % stage, bin_dtype=bin_dtype),
                       ]),
                       filter_size=(1, 1),
                       name='lut_conv2_%d' % stage,
                       fw_dtype=bin_dtype),
                ])
                
        # Register both models with the Switcher; 'dense' is active first.
        model_dict = {}
        model_dict['dense'] = self.dense_conv
        model_dict['lut'] = self.lut_conv
        
        super(DistillationConv, self).__init__(model_dict=model_dict, init_model_name='dense', input_shape=input_shape, name=name)

    
class MyNetwork(bb.Sequential):
    """Shared convolution trunk followed by two parallel output heads.

    ``net_cnv`` is a RealToBinary layer followed by 29 DistillationConv
    stages.  Its output feeds both ``net_cls`` (10 output channels after
    Reduce) and ``net_seg`` (1 output channel after Reduce).
    """

    def __init__(self):
        # convolutions
        self.net_cnv = bb.Sequential([])
        self.net_cnv.append(bb.RealToBinary(frame_modulation_size=frame_modulation_size, bin_dtype=bin_dtype))
        for i in range(29):
            self.net_cnv.append(DistillationConv(72, 36, stage=i, bin_dtype=bin_dtype))
        
        # classification
        # NOTE(review): this head and the segmentation head below are both
        # created with stage=29, so their layer names are identical
        # (e.g. 'DistillationConv_29') -- confirm this is intended.
        self.net_cls = bb.Sequential([
                            DistillationConv(144, 10*depth_integration_size, filter_size=(1, 1), batch_norm=False, stage=29, bin_dtype=bin_dtype),
                            bb.BinaryToReal(frame_integration_size=frame_modulation_size, bin_dtype=bin_dtype),
                            bb.Reduce(integration_size=depth_integration_size),
                        ])
        
        # segmentation
        self.net_seg = bb.Sequential([
                            DistillationConv(72, 1*depth_integration_size, filter_size=(1, 1), batch_norm=False, stage=29, bin_dtype=bin_dtype),
                            bb.BinaryToReal(frame_integration_size=frame_modulation_size, bin_dtype=bin_dtype),
                            bb.Reduce(integration_size=depth_integration_size),
                        ])
        super(MyNetwork, self).__init__(model_list=[self.net_cnv, self.net_cls, self.net_seg])
    
#    def send_command(self, cmd):
#        self.net_cnv.send_command(cmd)
#        self.net_cls.send_command(cmd)
#        self.net_seg.send_command(cmd)
        
    def set_input_shape(self, shape):
        """Set the trunk's input shape and pass its output shape to both heads.

        Unlike the sub-networks' set_input_shape, this method returns nothing.
        """
        shape = self.net_cnv.set_input_shape(shape)
        self.net_cls.set_input_shape(shape)
        self.net_seg.set_input_shape(shape)
    
#    def get_info(self, depth=0):
#        return self.net_cnv.get_info(depth) + self.net_cls.get_info(depth) + self.net_seg.get_info(depth)
            
    def forward(self, x, train):
        """Run the trunk once and return ``(classification, segmentation)`` outputs."""
        x = self.net_cnv.forward(x, train)
        y0 = self.net_cls.forward(x, train)
        y1 = self.net_seg.forward(x, train)
        return y0, y1
    
    def backward(self, dy0, dy1):
        """Back-propagate both head gradients and return the trunk's input gradient.

        The gradients coming out of the two heads are added before being
        passed to the shared trunk.
        """
        dy0 = self.net_cls.backward(dy0)
        dy1 = self.net_seg.backward(dy1)
        dx = self.net_cnv.backward(dy0 + dy1)
        return dx
        
# Instantiate the network and configure it for the tiled input size
# (1 channel, rows*28 x cols*28 pixels).
net = MyNetwork()

net.set_input_shape([1, rows*28, cols*28])
net.set_name(net_name)
# Sends the "binary true" command to the network (presumably enables
# binarized operation in the layers -- confirm against BinaryBrain docs).
net.send_command("binary true")

In [None]:
# Print the network description; 2 is presumably the nesting depth to show
# (cf. the commented-out get_info(self, depth=0) in MyNetwork) -- confirm.
print(net.get_info(2))

## 学習実施

学習を行います

In [None]:
def learning(net, epochs=2):
    """Train ``net`` on ``loader_train`` and evaluate it on ``loader_test``.

    The first network output is trained against teacher channels 0-9 with
    softmax cross entropy, the second against ``1.0 - channel 10`` with
    sigmoid cross entropy.  After every epoch the network is saved to
    ``data_path``, the test loss/accuracy are printed and ``view`` is called
    on ``loader_test``.

    Args:
        net: Network providing ``forward``, ``backward``, ``get_parameters``
            and ``get_gradients``.
        epochs: Number of passes over ``loader_train``.
    """
    cls_loss = bb.LossSoftmaxCrossEntropy()
    seg_loss = bb.LossSigmoidCrossEntropy()
    cls_acc = bb.MetricsCategoricalAccuracy()
    seg_acc = bb.MetricsBinaryCategoricalAccuracy()
    adam = bb.OptimizerAdam()

    adam.set_variables(net.get_parameters(), net.get_gradients())

    def reset_trackers():
        # Clear accumulated loss/accuracy before a train or test pass.
        for tracker in (cls_loss, cls_acc, seg_loss, seg_acc):
            tracker.clear()

    def to_buffers(x_imgs, t_imgs):
        # Convert one batch into the input buffer and the two teacher buffers.
        x_buf = bb.FrameBuffer.from_numpy(np.array(x_imgs).astype(np.float32))
        t0_buf = bb.FrameBuffer.from_numpy(np.array(t_imgs[:,0:10,:,:]).astype(np.float32))
        t1_buf = bb.FrameBuffer.from_numpy(1.0 - np.array(t_imgs[:,10:11,:,:]).astype(np.float32))
        return x_buf, t0_buf, t1_buf

    def accumulate(y0_buf, y1_buf, t0_buf, t1_buf):
        # Update losses and metrics; return the two loss gradients.
        dy0_buf = cls_loss.calculate(y0_buf, t0_buf)
        dy1_buf = seg_loss.calculate(y1_buf, t1_buf)
        cls_acc.calculate(y0_buf, t0_buf)
        seg_acc.calculate(y1_buf, t1_buf)
        return dy0_buf, dy1_buf

    for epoch in range(epochs):
        # ---- training pass ----
        reset_trackers()
        with tqdm(loader_train) as progress:
            for x_imgs, t_imgs in progress:
                x_buf, t0_buf, t1_buf = to_buffers(x_imgs, t_imgs)
                y0_buf, y1_buf = net.forward(x_buf, train=True)

                dy0_buf, dy1_buf = accumulate(y0_buf, y1_buf, t0_buf, t1_buf)

                net.backward(dy0_buf, dy1_buf)
                adam.update()

                progress.set_postfix(loss0=cls_loss.get(), acc0=cls_acc.get(), loss1=seg_loss.get(), acc1=seg_acc.get())

        # ---- test pass ----
        reset_trackers()
        for x_imgs, t_imgs in loader_test:
            x_buf, t0_buf, t1_buf = to_buffers(x_imgs, t_imgs)
            y0_buf, y1_buf = net.forward(x_buf, train=False)
            accumulate(y0_buf, y1_buf, t0_buf, t1_buf)

        bb.save_networks(data_path, net)

        print('epoch[%d] : loss0=%f acc0=%f loss1=%f acc1=%f' % (epoch, cls_loss.get(), cls_acc.get(), seg_loss.get(), seg_acc.get()))
        view(net, loader_test)

In [None]:
# Load a previously saved network from data_path into `net`
# (no name is given here, unlike the named loads used further below).
bb.load_networks(data_path, net)

In [None]:
# Train with the DenseAffine models.
# The commented-out condition below would skip training when a saved
# 'learn_dense' network can be loaded; `if True:` forces training every time.
#if not bb.load_networks(data_path, net, 'learn_dense'):
if True:
    net.send_command('switch_model dense')  # switch the whole network to the dense model
    learning(net, 32)
    bb.save_networks(data_path, net, 'learn_dense')
    view(net, loader_test)

In [None]:
# Save the current network under the name 'learn_dense'.
bb.save_networks(data_path, net, 'learn_dense')

In [None]:
# Reload the network saved as 'learn_dense' and plot its results on the test data.
bb.load_networks(data_path, net, 'learn_dense')
view(net, loader_test)

In [None]:
# net[0] is the convolution trunk; element 0 of it is RealToBinary, so
# net[0][1] is the first DistillationConv stage.
print(net[0][1].get_info(2))

In [None]:
# Layer-by-layer distillation: starting from the all-dense network, switch
# the trunk stages to their LUT model one at a time, from stage 0 to 28.
# NOTE(review): only the 29 trunk stages (net[0][1..29]) are switched here;
# the DistillationConv layers of the two output heads are not switched to
# 'lut' by this loop -- confirm whether that is intended before RTL export.
net.send_command('switch_model dense')
for i in range(29):
    print('----- layer %d -----'%i)
    net[0][i+1].send_command('switch_model lut')
    # Train only when no saved result for this stage could be loaded.
    if not bb.load_networks(data_path, net, 'distilation%d'%i):
        # Train only this one layer (lock everything, then unlock stage i)
        net.send_command('parameter_lock true')
        net[0][i+1].send_command('parameter_lock false')
        learning(net, epochs=2)
        
        # Fine-tuning: lock everything, then unlock the stages with j > i,
        # i.e. the stages after the current one.
        # NOTE(review): the original comment said "preceding layers", which
        # does not match the `j > i` condition.
        net.send_command('parameter_lock true')
        for j in range(29):
            if j > i:
                net[0][j+1].send_command('parameter_lock false')
        learning(net, epochs=2)
        bb.save_networks(data_path, net, 'distilation%d'%i)
    else:
        view(net, loader_test)

## RTL出力

Verilog 変換を行います

In [None]:
# export verilog
# Write a `timescale header followed by the network dumped by
# bb.dump_verilog_lut_cnv_layers under the module name rtl_module_name.
with open(output_velilog_file, 'w') as f:
    f.write('`timescale 1ns / 1ps\n\n')
    bb.dump_verilog_lut_cnv_layers(f, rtl_module_name, net)

# Copy the result over the file used for simulation (overwrites it)
shutil.copyfile(output_velilog_file, sim_velilog_file)