# MNISTでセグメンテーションに挑戦



In [None]:
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torchvision
import torchvision.transforms as transforms

import binarybrain as bb

## 初期設定

In [None]:
# configuration
bb.set_device(0)  # select BinaryBrain compute device 0

net_name = 'MnistSegmentationDistillation'
data_path = os.path.join('./data/', net_name)  # directory used by save_networks / load_models

bin_mode = True  # True: BIT-typed binary network wrapped in RealToBinary / BinaryToReal
frame_modulation_size = 3  # frame count passed to RealToBinary / BinaryToReal
epochs = 0  # NOTE(review): not read by any visible cell; learning() receives epochs explicitly
mini_batch_size = 8

## データセット準備

MNISTの数字を rows×cols のタイル状に並べた入力画像と、数字10クラス＋背景の計11チャネルからなる教師画像のペアで、データセットを自作する

In [None]:
# dataset
dataset_path = './data/'
dataset_train = torchvision.datasets.MNIST(root=dataset_path, train=True, transform=transforms.ToTensor(), download=True)
dataset_test  = torchvision.datasets.MNIST(root=dataset_path, train=False, transform=transforms.ToTensor(), download=True)

# Each composite image is a rows x cols grid of 28x28 MNIST digits.
rows = 3
cols = 3

# Build class weights from pixel-area ratios.
#   areas[k] (k = 0..9): sum over training images of label k of the mean
#                        foreground intensity, divided by the training-set size
#   areas[10]          : same for the background intensity (1 - pixel) over all images
# Computed directly on the raw uint8 tensor: ToTensor() only divides by 255, so
# this matches iterating dataset_train without converting 60,000 PIL images one
# by one. Summing in int64 avoids materialising a large float copy of the data.
pixels_per_image = dataset_train.data[0].numel()
fg_means = dataset_train.data.sum(dim=(1, 2), dtype=torch.int64).numpy() / (255.0 * pixels_per_image)
areas = np.zeros((11))
np.add.at(areas, dataset_train.targets.numpy(), fg_means)  # unbuffered: repeated labels accumulate
areas[10] = np.sum(1.0 - fg_means)
areas /= len(dataset_train)

# Inverse-area weights normalised so the largest weight is 1.0.
wight = 1 / areas
wight /= np.max(wight)

def make_teacher_image(gen, rows, cols, margin=0):
    """Tile ``rows x cols`` digits into one input image and build its teacher maps.

    Args:
        gen: iterator yielding ``(img, label)`` pairs, where ``img`` is a 28x28
            digit image (array/tensor assignable to a 28x28 slice) and ``label``
            is its digit class 0..9.
        rows: number of digit rows in the composite image.
        cols: number of digit columns in the composite image.
        margin: pixels cropped from every side of the teacher maps, so they can
            match a network whose output is smaller than its input. ``0`` (the
            default) keeps the full size.

    Returns:
        ``(source_img, teaching_img)``. ``source_img`` has shape
        ``(1, rows*28, cols*28)``. ``teaching_img`` has shape
        ``(11, rows*28 - 2*margin, cols*28 - 2*margin)``: channel ``k`` (0..9)
        is 1.0 where digit ``k`` has a pixel > 0.5, channel 10 is 0.1 on
        background pixels, and everything else is 0.0.
    """
    height, width = rows * 28, cols * 28
    source_img = np.zeros((1, height, width), dtype=np.float32)
    teaching_img = np.zeros((11, height, width), dtype=np.float32)
    for row in range(rows):
        for col in range(cols):
            y, x = row * 28, col * 28
            img, label = next(gen)
            source_img[0, y:y+28, x:x+28] = img
            teaching_img[label, y:y+28, x:x+28] = img
            teaching_img[10, y:y+28, x:x+28] = 1.0 - img

    # Binarise at 0.5, then down-weight the (dominant) background channel.
    teaching_img = (teaching_img > 0.5).astype(np.float32)
    teaching_img[10] *= 0.1

    # Explicit end indices: a slice like ``0:-0`` would be empty when margin == 0.
    return source_img, teaching_img[:, margin:height - margin, margin:width - margin]

def transform_data(dataset, n, rows, cols, margin):
    """Build ``n`` tiled (source, teacher) image pairs from ``dataset``.

    Samples are taken in order starting at index 0 and wrap around when the
    end of ``dataset`` is reached; each pair is produced by ``make_teacher_image``.

    Returns:
        ``(source_imgs, teaching_imgs)``: two lists of length ``n``.
    """
    def endless_samples():
        size = len(dataset)
        k = 0
        while True:
            yield dataset[k % size]
            k += 1

    samples = endless_samples()
    pairs = [make_teacher_image(samples, rows, cols, margin) for _ in range(n)]
    source_imgs = [src for src, _ in pairs]
    teaching_imgs = [tgt for _, tgt in pairs]
    return source_imgs, teaching_imgs

class MyDatasets(torch.utils.data.Dataset):
    """Map-style dataset over pre-built ``(source, teacher)`` image pairs.

    Args:
        source_imgs: indexable sequence of input images.
        teaching_imgs: indexable sequence of teacher images, parallel to ``source_imgs``.
        transforms: optional callable ``(source, teacher) -> (source, teacher)``
            applied on every access; skipped when falsy.
    """

    def __init__(self, source_imgs, teaching_imgs, transforms=None):
        self.source_imgs = source_imgs
        self.teaching_imgs = teaching_imgs
        self.transforms = transforms

    def __len__(self):
        return len(self.source_imgs)

    def __getitem__(self, index):
        pair = (self.source_imgs[index], self.teaching_imgs[index])
        if not self.transforms:
            return pair
        src, tgt = self.transforms(*pair)
        return src, tgt


# Build the tiled train/test sets (5 mini-batches each; use mini_batch_size*1000
# for a full-size training set). margin=29 crops every teacher map to the
# network's output size: the 27 stages of valid 3x3 convolutions remove 27 px
# per side and the final 5x5 convolution 2 more (assuming it is also unpadded).
source_imgs_train, teaching_imgs_train = transform_data(
    dataset_train, mini_batch_size * 5, rows, cols, margin=29)
my_dataset_train = MyDatasets(source_imgs_train, teaching_imgs_train)

source_imgs_test, teaching_imgs_test = transform_data(
    dataset_test, mini_batch_size * 5, rows, cols, margin=29)
my_dataset_test = MyDatasets(source_imgs_test, teaching_imgs_test)

# Training batches are reshuffled each epoch; the test order stays fixed.
loader_train = torch.utils.data.DataLoader(
    dataset=my_dataset_train, batch_size=mini_batch_size, shuffle=True)
loader_test = torch.utils.data.DataLoader(
    dataset=my_dataset_test, batch_size=mini_batch_size, shuffle=False)

In [None]:
def plt_data(x, y):
    """Show an input image next to its 11 channel maps in a new 1x12 figure.

    Panel 1 is ``x[0]``; panels 2..12 are ``y[0]`` .. ``y[10]``.
    """
    plt.figure(figsize=(16,8))
    for pos in range(1, 13):
        plt.subplot(1, 12, pos)
        plt.imshow(x[0] if pos == 1 else y[pos - 2], 'gray')
    plt.show()

In [None]:
# Preview: print the shapes of one test sample and plot up to five samples of
# the first test batch. plt_data opens its own figure, so none is created here
# (an extra plt.figure() would only leave an empty figure behind).
for source_imgs, teaching_imgs in loader_test:
    print(source_imgs[0].shape)
    print(teaching_imgs[0].shape)
    # Bound by the actual batch length so a short batch cannot raise IndexError.
    for i in range(min(len(source_imgs), 5)):
        plt_data(source_imgs[i], teaching_imgs[i])
    break

## ネットワーク構築

In [None]:
# In binary mode the BIT type can be used to reduce memory usage.
if bin_mode:
    bin_dtype = bb.DType.BIT
else:
    bin_dtype = bb.DType.FP32

class DistillationConv(bb.Switcher):
    """Switchable convolution stage used to distil a dense CNN into a LUT-Net.

    Builds two alternative implementations of one convolution stage and
    registers them with ``bb.Switcher`` under the names ``'dense'`` (active
    initially) and ``'lut'``. The active implementation is selected at run time
    with ``send_command('switch_model ...')``.

    Args:
        hidden_ch (int): number of intermediate channels on the LUT-Net side.
        output_ch (int): number of output channels of both implementations.
        stage (int): stage index, used to give every layer a unique name.
        filter_size (tuple): spatial filter size of the dense convolution and of
            the LUT depthwise convolution. Defaults to (3, 3).
        padding (str): padding mode of those same convolutions. Defaults to 'valid'.
        input_shape: optional input shape forwarded to ``bb.Switcher``.
        bin_dtype (DType): binary data type, either bb.DType.FP32 or bb.DType.BIT.
    """
    
    def __init__(self, hidden_ch, output_ch, stage, *, filter_size=(3,3), padding='valid', input_shape=None, bin_dtype=bb.DType.FP32):
        name = 'DistillationConv_%d' % stage
        # Dense: convolution whose per-window kernel is DenseAffine -> BatchNorm -> ReLU
        self.dense_conv = bb.Convolution2d(
               bb.Sequential([
                   bb.DenseAffine([output_ch], name='dense_affine_%d' % stage),
                   bb.BatchNormalization(name='dense_bn_%d' % stage),
                   bb.ReLU(name='dense_act_%d' % stage, bin_dtype=bin_dtype),
               ]),
               filter_size=filter_size,
               padding=padding,
               name='dense_conv_%d' % stage,
               fw_dtype=bin_dtype)
        
        # LUT: pointwise -> depthwise -> pointwise convolutions built from DifferentiableLut layers
        self.lut_conv = bb.Sequential([
                # pointwise
                # 1x1 conv: hidden_ch*6 randomly connected LUTs, then hidden_ch LUTs with
                # 'serial' connections (presumably each reading 6 consecutive outputs — confirm)
                bb.Convolution2d(
                   bb.Sequential([
                       bb.DifferentiableLut([hidden_ch*6, 1, 1], connection='random', name='lut_conv0_lut0_%d' % stage, bin_dtype=bin_dtype),
                       bb.DifferentiableLut([hidden_ch,   1, 1], connection='serial', name='lut_conv0_lut1_%d' % stage, bin_dtype=bin_dtype),
                   ]),
                   filter_size=(1, 1),
                   name='lut_conv0_%d' % stage,
                   fw_dtype=bin_dtype),
            
                # depthwise
                # filter_size conv with 'depthwise' connections (presumably each LUT sees only
                # its own channel's window — confirm against BinaryBrain docs)
                bb.Convolution2d(
                   bb.Sequential([
                       bb.DifferentiableLut([hidden_ch, 1, 1], connection='depthwise', name='lut_conv1_hidden0_%d' % stage, bin_dtype=bin_dtype),
                   ]),
                   filter_size=filter_size,
                   padding=padding,
                   name='lut_conv1_%d' % stage,
                   fw_dtype=bin_dtype),

                # pointwise
                # 1x1 conv down to output_ch, same random -> serial LUT pattern as the first pointwise
                bb.Convolution2d(
                   bb.Sequential([
                       bb.DifferentiableLut([output_ch*6, 1, 1], connection='random', name='lut_conv2_input0_%d' % stage, bin_dtype=bin_dtype),
                       bb.DifferentiableLut([output_ch,   1, 1], connection='serial', name='lut_conv2_input1_%d' % stage, bin_dtype=bin_dtype),
                   ]),
                   filter_size=(1, 1),
                   name='lut_conv2_%d' % stage,
                   fw_dtype=bin_dtype),
            ])
                
        # Register both implementations; 'dense' is the one active after construction.
        model_dict = {}
        model_dict['dense'] = self.dense_conv
        model_dict['lut'] = self.lut_conv
        
        super(DistillationConv, self).__init__(model_dict=model_dict, init_model_name='dense', input_shape=input_shape, name=name)
        

# Backbone: 27 DistillationConv stages. Stages 0-7 output 36 channels and
# stages 8-26 output 72; each stage's LUT hidden width is twice its output width.
sub_net = bb.Sequential()
for i in range(27):
    out_ch = 36 if i < 8 else 72
    sub_net.append(DistillationConv(out_ch * 2, out_ch, stage=i, bin_dtype=bin_dtype))

# Head: dense 5x5 convolution (512 hidden units) producing the 11 output maps
# (digits 0-9 plus background).
sub_net.append(bb.Convolution2d(
    bb.Sequential([
        bb.DenseAffine([512], name='dense_affine_fc0'),
        bb.BatchNormalization(name='dense_bn_fc0'),
        bb.ReLU(bin_dtype=bin_dtype, name='dense_act_fc0'),
        bb.DenseAffine([11], name='dense_affine_fc1'),
        bb.BatchNormalization(name='dense_bn_fc1'),
        bb.ReLU(bin_dtype=bin_dtype, name='dense_act_fc1'),
    ]),
    filter_size=(5, 5),
    name='dense_conv_fc',
    fw_dtype=bin_dtype,
))
    
# define network: in binary mode sub_net is wrapped between RealToBinary and
# BinaryToReal (both using frame_modulation_size frames); otherwise it is used as-is.
net = sub_net
if bin_mode:
    net = bb.Sequential([
        bb.RealToBinary(frame_modulation_size=frame_modulation_size, bin_dtype=bin_dtype),
        sub_net,
        bb.BinaryToReal(frame_integration_size=frame_modulation_size, bin_dtype=bin_dtype),
    ])

net.set_input_shape([1, rows*28, cols*28])
net.set_name(net_name)
if bin_mode:
    net.send_command("binary true")

print(net.get_info(1))

## 学習実施

学習を行います

In [None]:
def view(net, loader, max_batches=2):
    """Plot the network prediction for the first sample of the first few batches.

    Args:
        net: BinaryBrain model, run with ``forward(..., train=False)``.
        loader: iterable of ``(x_imgs, t_imgs)`` batches; teacher images are unused.
        max_batches: number of batches to visualise (default 2, the previous
            hard-coded value). At least one batch is always shown.
    """
    for num, (x_imgs, _t_imgs) in enumerate(loader, start=1):
        x_buf = bb.FrameBuffer.from_numpy(np.array(x_imgs).astype(np.float32))
        y_buf = net.forward(x_buf, train=False)
        # plt_data opens its own figure; creating one here would leave an empty figure.
        plt_data(x_imgs[0], y_buf.numpy()[0])
        # Check after plotting so no extra batch is pulled from the loader.
        if num >= max_batches:
            break

In [None]:
def learning(epochs=2):
    """Train the global ``net`` on ``loader_train`` and evaluate it on ``loader_test``.

    A new softmax cross-entropy loss, categorical-accuracy metric and Adam
    optimizer (bound to ``net``'s parameters and gradients) are created on every
    call. After each epoch the network is saved to ``data_path`` and the test
    loss/accuracy is printed. With ``epochs=0`` only the optimizer is set up.

    Args:
        epochs: number of passes over ``loader_train``.
    """
    def as_buffers(x_imgs, t_imgs):
        # Convert a loader batch (inputs, teachers) to float32 BinaryBrain frame buffers.
        x_buf = bb.FrameBuffer.from_numpy(np.array(x_imgs).astype(np.float32))
        t_buf = bb.FrameBuffer.from_numpy(np.array(t_imgs).astype(np.float32))
        return x_buf, t_buf

    loss = bb.LossSoftmaxCrossEntropy()
    metrics = bb.MetricsCategoricalAccuracy()
    optimizer = bb.OptimizerAdam()
    optimizer.set_variables(net.get_parameters(), net.get_gradients())

    for epoch in range(epochs):
        # training pass
        loss.clear()
        metrics.clear()
        with tqdm(loader_train) as progress:
            for x_imgs, t_imgs in progress:
                x_buf, t_buf = as_buffers(x_imgs, t_imgs)
                y_buf = net.forward(x_buf, train=True)
                dy_buf = loss.calculate(y_buf, t_buf)
                metrics.calculate(y_buf, t_buf)
                net.backward(dy_buf)
                optimizer.update()
                progress.set_postfix(loss=loss.get(), acc=metrics.get())

        # evaluation pass
        loss.clear()
        metrics.clear()
        for x_imgs, t_imgs in loader_test:
            x_buf, t_buf = as_buffers(x_imgs, t_imgs)
            y_buf = net.forward(x_buf, train=False)
            loss.calculate(y_buf, t_buf)
            metrics.calculate(y_buf, t_buf)

        bb.save_networks(data_path, net, force_flatten=True)
        print('epoch[%d] : loss=%f acc=%f' % (epoch, loss.get(), metrics.get()))

## 保存実験

In [None]:
if False:
    # Load the base snapshot stored in the 'bin' format.
    bb.load_models(os.path.join(data_path, '20210106_185107_base'), net, read_layers=True, file_format='bin')

    if False:
        # Re-save it in every combination of layout (with/without per-layer
        # files) and format, in this order.
        save_variants = [
            ('test_bb_net_layers', True, 'bb_net'),
            ('test_bb_net', False, 'bb_net'),
            ('test_pickle', False, 'pickle'),
            ('test_pickle_layers', True, 'pickle'),
        ]
        for dir_name, write_layers, file_format in save_variants:
            bb.save_models(os.path.join(data_path, dir_name), net, write_layers=write_layers, file_format=file_format)

        # Also pickle the whole network object directly.
        import pickle
        with open(os.path.join(data_path, 'MnistSegmentationDistillation_test.pickle'), 'wb') as f:
            pickle.dump(net, f)

## 読出し色々

In [None]:
#bb.load_networks(data_path, net, read_layers=True, force_flatten=True)
## bb.load_networks(data_path, net, read_layers=False, force_flatten=True)

In [None]:
# bb.save_networks(data_path, net, read_layers=True, backups=-1)
# bb.save_networks(data_path, net, read_layers=True, backups=-1, file_format='pickle')

In [None]:
if False:
    # Disabled toggle: load the 'test_bb_net' snapshot (saved with write_layers=False,
    # file_format='bb_net' in the save experiment) using load_models' default arguments.
    bb.load_models(os.path.join(data_path, 'test_bb_net'), net)

In [None]:
if True:
    # Enabled toggle: load the per-layer bb_net snapshot. NOTE: these files are only
    # written by the (disabled) save-experiment cell, so they must already exist on disk.
    bb.load_models(os.path.join(data_path, 'test_bb_net_layers'), net, read_layers=True, file_format='bb_net')

In [None]:
if False:
    # Disabled toggle: load the whole-model snapshot saved in the 'pickle' format.
    bb.load_models(os.path.join(data_path, 'test_pickle'), net, file_format='pickle')

In [None]:
if False:
    # Disabled toggle: load the per-layer snapshot saved in the 'pickle' format.
    bb.load_models(os.path.join(data_path, 'test_pickle_layers'), net, read_layers=True, file_format='pickle')

In [None]:
if False:
    # Disabled toggle: restore the whole network object from the raw pickle file
    # 'MnistSegmentationDistillation_test.pickle'.
    # NOTE: pickle.load can execute arbitrary code — only load files you created yourself.
    import pickle
    with open(os.path.join(data_path, 'MnistSegmentationDistillation_test.pickle'), 'rb') as f:
        net = pickle.load(f)
    # In binary mode net is [RealToBinary, sub_net, BinaryToReal]; otherwise net
    # *is* sub_net, and net[1] would only be its second stage.
    sub_net = net[1] if bin_mode else net

## ちゃんと読めてるか確認

In [None]:
# Training with the DenseAffine (dense) models; use learning(32) for real training.
# With epochs=0 nothing is trained, so this only visualises the loaded network.
learning(epochs=0)
view(net, loader_test)

In [None]:
# bb.save_networks(data_path, net, backups=3, write_layers=True, force_flatten=True)

In [None]:
# Switch the stages to their LUT model one at a time, from the lowest layer up,
# and display the network output after each switch (retraining is disabled here).
for i in range(27):
    print(f"----- layer {i} -----")
    sub_net[i].send_command('switch_model lut')
    # Lock all parameters, then unlock only the stage just switched
    # (presumably "lock" freezes parameters during learning — confirm).
    net.send_command('parameter_lock true')
    sub_net[i].send_command('parameter_lock false')
    view(net, loader_test)

In [None]:
# Distil layer by layer, from the lowest stage up: switch stage i to its LUT
# model, lock every other parameter, train for 2 epochs, then display the output.
for i in range(27):
    stage = sub_net[i]
    stage.send_command('switch_model lut')
    net.send_command('parameter_lock true')
    stage.send_command('parameter_lock false')
    learning(epochs=2)
    view(net, loader_test)

In [None]:
# Scratch: small dict with three int keys.
a = dict([(1, 2), (2, 3), (14, 111)])

In [None]:
# Scratch: display the number of keys in `a`.
len(a)

In [None]:
# Scratch: compare a str with its UTF-8 encoding. `str_a` is not defined in any
# other cell (NameError on a fresh kernel), so a sample containing multi-byte
# characters is defined here.
str_a = 'MNISTでセグメンテーションに挑戦'
str_b = str_a.encode(encoding='utf-8')

In [None]:
# Scratch: display the UTF-8 byte length of `str_b`.
len(str_b)

In [None]:
# Scratch: display the raw bytes of `str_b`.
str_b

In [None]:
# Scratch: decode the UTF-8 bytes back to a str.
str(str_b, encoding='utf-8')

In [None]:
# Scratch: rebind `a` to a boolean.
a = bool(0)

In [None]:
# Scratch: bool is an int subclass, so it can be serialised to a single byte.
a.to_bytes(length=1, byteorder='little')