In [None]:
import os
import shutil
import random
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
#from tqdm import tqdm
import copy

import torch
import torchvision
import torchvision.transforms as transforms

import binarybrain as bb

In [None]:
# Show the version string reported by BinaryBrain as this cell's output.
bb.get_version_string()

In [None]:
# configuration
# Select device 0 for BinaryBrain (presumably the GPU index - TODO confirm).
bb.set_device(0)

# Network name; also used as the sub-directory under ./data/ for cached files.
net_name               = 'mnist_iir_segmentation'
data_path              = os.path.join('./data/', net_name)

# Disabled RTL / Verilog export settings (kept for reference).
#rtl_sim_path           = '../../verilog/mnist'
#rtl_module_name        = 'MnistIirSemanticSegmentation'
#output_velilog_file    = os.path.join(data_path, net_name + '.v')
#sim_velilog_file       = os.path.join(rtl_sim_path, rtl_module_name + '.v')

# bin_mode selects bb.DType.BIT (True) or bb.DType.FP32 (False) for bin_dtype.
bin_mode               = True
# Passed to RealToBinary / BinaryToReal as the frame modulation/integration size.
frame_modulation_size  = 3
# NOTE(review): not referenced by any cell visible in this notebook.
depth_integration_size = 1
# Number of passes over loader_train in the training loop.
epochs                 = 8
# Batch size used by both DataLoaders.
mini_batch_size        = 16

In [None]:
# Number of tiles placed in each composite image (rows x cols)
rows=2
cols=2

# Height and width, in pixels, of a single tile
img_h = 32
img_w = 32

def make_teacher_image(gen, rows, cols, margin=0):
    """Compose a rows x cols grid of samples into one source / teacher image pair.

    Args:
        gen: iterator yielding ``(img, label)`` pairs; ``rows * cols`` items are consumed.
        rows: number of tile rows.
        cols: number of tile columns.
        margin: when positive, this many pixels are cropped from every border of
            the teacher image only (the source image keeps its full size).

    Returns:
        ``(source_img, teaching_img)``.  ``source_img`` has 1 channel and
        ``teaching_img`` has 11 channels: channel ``label`` holds the tile's
        image, channel 10 holds its complement, both thresholded at 0.5.
    """
    canvas_h = rows * img_h
    canvas_w = cols * img_w
    source_img = np.zeros((1, canvas_h, canvas_w), dtype=np.float32)
    teaching_img = np.zeros((11, canvas_h, canvas_w), dtype=np.float32)

    for row in range(rows):
        ys = slice(row * img_h, (row + 1) * img_h)
        for col in range(cols):
            xs = slice(col * img_w, (col + 1) * img_w)
            tile, label = next(gen)
            source_img[0, ys, xs] = tile
            teaching_img[label, ys, xs] = tile
            teaching_img[10, ys, xs] = (1.0 - tile)

    # binarize the teacher channels
    teaching_img = (teaching_img > 0.5).astype(np.float32)

    # randomly invert the source image (the teacher image is left untouched)
    if random.random() > 0.5:
        source_img = 1.0 - source_img

    if margin > 0:
        teaching_img = teaching_img[:, margin:-margin, margin:-margin]
    return source_img, teaching_img

def transform_data(dataset, n, rows, cols, margin):
    """Build ``n`` tiled source / teacher image pairs from ``dataset``.

    Samples are taken from ``dataset`` in index order and wrap around to the
    start when the end is reached.

    Args:
        dataset: indexable collection of ``(img, label)`` samples supporting ``len``.
        n: number of composite images to generate.
        rows, cols, margin: forwarded to ``make_teacher_image``.

    Returns:
        ``(source_imgs, teaching_imgs)``, two lists of length ``n``.
    """
    def cycle_samples():
        size = len(dataset)
        pos = 0
        while True:
            yield dataset[pos % size]
            pos += 1

    sample_gen = cycle_samples()
    pairs = [make_teacher_image(sample_gen, rows, cols, margin) for _ in range(n)]
    source_imgs = [src for src, _ in pairs]
    teaching_imgs = [tch for _, tch in pairs]
    return source_imgs, teaching_imgs

class MyDatasets(torch.utils.data.Dataset):
    """Dataset pairing pre-built source images with their teacher images.

    Args:
        source_imgs: sequence of source images.
        teaching_imgs: sequence of teacher images, index-aligned with ``source_imgs``.
        transforms: optional callable ``(source, teacher) -> (source, teacher)``
            applied to each pair on access.
    """

    def __init__(self, source_imgs, teaching_imgs, transforms=None):
        self.source_imgs = source_imgs
        self.teaching_imgs = teaching_imgs
        self.transforms = transforms

    def __len__(self):
        """Return the number of source images."""
        return len(self.source_imgs)

    def __getitem__(self, index):
        """Return the ``(source, teacher)`` pair at ``index``, transformed if configured."""
        pair = (self.source_imgs[index], self.teaching_imgs[index])
        if not self.transforms:
            return pair
        src, tch = self.transforms(*pair)
        return src, tch

# Preprocessing applied to every MNIST sample: resize to one tile, then
# convert to a tensor.  Resize expects its size as (height, width), so the
# tile height must come first (the order only matters once img_h != img_w).
transform = transforms.Compose([
                transforms.Resize((img_h, img_w)),
                transforms.ToTensor(),
            ])
    
# dataset
# MNIST train / test splits, downloaded to dataset_path when missing and
# preprocessed with `transform`.
dataset_path = './data/'
dataset_train = torchvision.datasets.MNIST(root=dataset_path, train=True, transform=transform, download=True)
dataset_test  = torchvision.datasets.MNIST(root=dataset_path, train=False, transform=transform, download=True)

# Create per-class weights from area ratios.
# areas[0..9] accumulate the mean pixel value of images of each digit class,
# areas[10] accumulates the mean of the complement (background) of every image.
areas = np.zeros((11))
for img, label in dataset_train:
    img = img.numpy()
    areas[label] += np.mean(img)
    areas[10] += np.mean(1.0-img)
# normalize by the total number of training samples
areas /= len(dataset_train)

# Inverse-area weights, scaled so the largest weight is 1.
# NOTE(review): `weight` is not referenced by any cell visible in this notebook.
weight = 1. / areas
weight /= np.max(weight)

# Tile the samples into composite images for filter-style processing.
# The generated lists are cached in a pickle file under data_path and reused
# when that file already exists.
dataset_fname = os.path.join(data_path, 'dataset.pickle')
if os.path.exists(dataset_fname):
#if False:
    # NOTE(review): pickle.load can execute arbitrary code; only load a cache
    # file that this notebook wrote itself.
    with open(dataset_fname, 'rb') as f:
        # objects are read back in the same order they are dumped below
        source_imgs_train = pickle.load(f)
        teaching_imgs_train = pickle.load(f)
        source_imgs_test = pickle.load(f)
        teaching_imgs_test = pickle.load(f)
else:
    os.makedirs(data_path, exist_ok=True)
    # 4096 training and 128 test composites; margin 0 leaves the teacher image uncropped
    source_imgs_train, teaching_imgs_train = transform_data(dataset_train, 4096, rows, cols, 0) #29)
    source_imgs_test, teaching_imgs_test = transform_data(dataset_test, 128, rows, cols, 0) #, 29)
    with open(dataset_fname, 'wb') as f:
        pickle.dump(source_imgs_train, f)
        pickle.dump(teaching_imgs_train, f)
        pickle.dump(source_imgs_test, f)
        pickle.dump(teaching_imgs_test, f)

# Wrap the tiled lists as datasets (no extra transforms).
my_dataset_train = MyDatasets(source_imgs_train, teaching_imgs_train)
my_dataset_test = MyDatasets(source_imgs_test, teaching_imgs_test)

# Training batches are shuffled; test batches keep their order.
loader_train = torch.utils.data.DataLoader(dataset=my_dataset_train, batch_size=mini_batch_size, shuffle=True)
loader_test = torch.utils.data.DataLoader(dataset=my_dataset_test, batch_size=mini_batch_size, shuffle=False)

In [None]:
# Check the training data by displaying it
def plt_data(x, y):
    """Plot one source image next to its 11 teacher channels in a single row.

    Args:
        x: source image; ``x[0]`` is drawn in the first panel.
        y: teacher image; ``y[0]`` .. ``y[9]`` are drawn as the digit classes
            and ``y[10]`` as the background channel.
    """
    plt.figure(figsize=(16,8))
    plt.subplot(1,12,1)
    plt.title('source')  # was misspelled 'sorce'
    plt.imshow(x[0], 'gray')
    for i in range(11):
        plt.subplot(1,12,2+i)
        # channels 0-9 are digit classes, channel 10 is the background
        plt.title('class=%d'%i if i < 10 else 'background')
        plt.imshow(y[i], 'gray')
    plt.tight_layout()
    plt.show()

# Preview up to four samples from the first test batch.  plt_data opens its
# own figure, so no figure is created here (the extra plt.figure call only
# produced an empty figure).
for source_imgs, teaching_imgs in loader_test:
    print(source_imgs[0].shape)
    print(teaching_imgs[0].shape)
    # bound by the actual batch length so a short batch cannot be over-indexed
    for i in range(min(len(source_imgs), 4)):
        plt_data(source_imgs[i], teaching_imgs[i])
    break

In [None]:
# In binary mode, using the BIT dtype can reduce memory usage
bin_dtype = bb.DType.BIT if bin_mode else bb.DType.FP32

def create_dense_affine(name, output_ch, fw_dtype=bin_dtype):
    """Build a binarized dense-affine stack: DenseAffine -> BatchNormalization -> Binarize.

    Args:
        name: prefix for the three layer names.
        output_ch: number of output channels; the affine output shape is ``[output_ch, 1, 1]``.
        fw_dtype: passed to ``bb.Binarize`` as ``bin_dtype``.
    """
    affine = bb.DenseAffine([output_ch, 1, 1], name=name + '_dense_affine')
    batch_norm = bb.BatchNormalization(name=name + '_dense_bn')
    activation = bb.Binarize(name=name + '_dense_act', bin_dtype=fw_dtype)
    return bb.Sequential([affine, batch_norm, activation])

def create_dense_conv(name, output_ch, filter_size=(3, 3), padding='same', fw_dtype=bin_dtype):
    """Build a binarized dense convolution layer.

    Wraps the stack from ``create_dense_affine`` in ``bb.Convolution2d``.

    Args:
        name: prefix for the generated layer names.
        output_ch: number of output channels.
        filter_size: convolution filter size passed to ``bb.Convolution2d``.
        padding: padding mode passed to ``bb.Convolution2d``.
        fw_dtype: dtype passed to both the inner stack and the convolution.
    """
    inner = create_dense_affine(name, output_ch, fw_dtype)
    conv_name = name + '_dense_conv'
    return bb.Convolution2d(
        inner,
        filter_size=filter_size,
        padding=padding,
        name=conv_name,
        fw_dtype=fw_dtype,
    )

def create_conv_block(name, output_ch=64):
    """Build a block of two consecutive dense convolutions with ``output_ch`` channels each."""
    layers = [create_dense_conv(name + suffix, output_ch) for suffix in ('_cnv0', '_cnv1')]
    return bb.Sequential(layers)


class ScaledNetwork(bb.Sequential):
    """Network block for one resolution level.

    The block up-samples a coarser input ``x1``, concatenates it with ``x0``
    and runs conv block ``cnv0``.  That result is max-pooled into ``v`` and,
    unless ``top`` is set, concatenated with ``u`` before conv block ``cnv1``
    produces ``y``.

    Args:
        name: prefix for the layer names of both conv blocks.
        hidden_ch: output channels of ``cnv0``.
        output_ch: output channels of ``cnv1``.
        top: when True the ``u`` input is ignored and ``cat1`` is skipped.
    """

    def __init__(self, name, hidden_ch=32, output_ch=32, top=False):
        self.hidden_ch = hidden_ch
        self.output_ch = output_ch
        self.top = top

        self.up = bb.UpSampling((2, 2), fw_dtype=bin_dtype)
        self.pool = bb.MaxPooling((2, 2), fw_dtype=bin_dtype)
        self.cat0 = bb.Concatenate()
        self.cat1 = bb.Concatenate()
        self.cnv0 = create_conv_block(name + '_cnv0', self.hidden_ch)
        self.cnv1 = create_conv_block(name + '_cnv1', self.output_ch)

        layers = [self.up, self.pool, self.cat0, self.cat1, self.cnv0, self.cnv1]
        super().__init__(layers)

    def set_input_shape(self, x0_shape, x1_shape, u_shape):
        """Propagate input shapes through the layers; returns ``(y_shape, v_shape)``."""
        up_shape = self.up.set_input_shape(x1_shape)
        hidden_shape = self.cnv0.set_input_shape(
            self.cat0.set_input_shape([x0_shape, up_shape]))
        v_shape = self.pool.set_input_shape(hidden_shape)
        if self.top:
            head_shape = hidden_shape
        else:
            head_shape = self.cat1.set_input_shape([u_shape, hidden_shape])
        return self.cnv1.set_input_shape(head_shape), v_shape

    def forward(self, x0, x1, u, train=True):
        """Run the block forward; returns ``(y, v)`` where ``v`` is the pooled ``cnv0`` output."""
        coarse = self.up.forward(x1, train=train)
        hidden = self.cnv0.forward(self.cat0.forward([x0, coarse]), train=train)
        v = self.pool.forward(hidden, train=train)
        if self.top:
            head_in = hidden
        else:
            head_in = self.cat1.forward([u, hidden])
        return self.cnv1.forward(head_in, train=train), v

    def backward(self, dy, dv):
        """Back-propagate ``dy`` (for ``y``) and ``dv`` (for ``v``); returns ``(dx0, dx1, du)``.

        ``du`` is ``None`` when ``top`` is set.
        """
        grad = self.cnv1.backward(dy)
        if self.top:
            du, grad_skip = None, grad
        else:
            du, grad_skip = self.cat1.backward([grad])
        grad_pool = self.pool.backward(dv)
        # the cnv0 output feeds both cnv1 (via cat1) and the pooling layer
        grad_hidden = self.cnv0.backward(grad_skip + grad_pool)
        dx0, dx1 = self.cat0.backward([grad_hidden])
        return dx0, self.up.backward(dx1), du

In [None]:
class MipmapNetwork(bb.Sequential):
    """Looped mipmap network built from two ``ScaledNetwork`` blocks.

    ``m_net`` runs at the input resolution and produces the 11-channel output.
    ``s_net`` is a single block that is applied in turn to every mipmap level.
    ``forward`` repeats the whole pass ``loop`` times, so values stored in the
    mipmap buffers feed back into later iterations.

    Args:
        loop: number of times the whole pass is repeated.
        ch: channel count of the mipmap buffers and of the hidden layers.
        depth: number of mipmap levels processed by ``s_net``.
    """

    def __init__(self, loop=6, ch=32, depth=3):
        self.loop    = loop
        self.depth   = depth
        self.ch      = ch
        self.shape   = None   # input node shape, set by set_input_shape()
        self.v_shape = None   # node shape of the last pooled output, set by forward()
        self.up      = bb.UpSampling((2, 2))
        self.m_net = ScaledNetwork('main', ch, 11, top=True)
        self.s_net = ScaledNetwork('sub', ch, ch)
        super(MipmapNetwork, self).__init__([self.m_net, self.s_net])

    def set_input_shape(self, shape):
        """Set the input node shape (a list ``[c, h, w]``) and return the output shape."""
        self.shape = copy.copy(shape)

        # main block: x0 is the input itself, x1 is a half-resolution buffer with ch channels
        x0_shape = copy.copy(shape)
        x1_shape = copy.copy(shape)
        x1_shape[0] = self.ch
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        u_shape = copy.copy(shape)
        y_shape, v_shape = self.m_net.set_input_shape(x0_shape, x1_shape, u_shape)

        # sub block: x0 is the half-resolution level, x1 the quarter-resolution level
        x0_shape = copy.copy(x1_shape)
        x0_shape[0] = self.ch
        x1_shape = copy.copy(x1_shape)
        x1_shape[0] = self.ch
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        self.s_net.set_input_shape(x0_shape, x1_shape, v_shape)

        return y_shape

    def make_mipmap(self, frame_size, dtype=bb.DType.FP32):
        """Create ``depth + 1`` zero-filled buffers; level ``i`` is the input size halved ``i + 1`` times."""
        h = self.shape[1]
        w = self.shape[2]
        mipmap = []
        for i in range(self.depth+1):
            h //= 2
            w //= 2
            buf = bb.FrameBuffer(frame_size, (self.ch, h, w), dtype=dtype)
            buf.fill_zero()
            mipmap.append(buf)
        return mipmap

    def forward(self, x, train=True):
        """Run ``loop`` passes over the mipmap pyramid and return the main block's output."""
        self.m_net.clear()
        self.s_net.clear()

        frame_size = x.get_frame_size()

        mipmap = self.make_mipmap(frame_size, dtype=bin_dtype)
        for _ in range(self.loop):
            y, v = self.m_net.forward(x, mipmap[0], None, train=train)
            for i in range(self.depth):
                # Forward the caller's train flag; without it the sub-network
                # always ran with the default train=True, even for inference.
                mipmap[i], v = self.s_net.forward(mipmap[i], mipmap[i+1], v, train=train)

        # remembered so backward() can build a gradient buffer of matching shape
        self.v_shape = v.get_node_shape()
        return y

    def backward(self, dy):
        """Back-propagate ``dy`` through all ``loop`` passes and return the input gradient.

        Raises:
            RuntimeError: if called before ``forward``.
        """
        if self.v_shape is None:
            raise RuntimeError('MipmapNetwork.backward() called before forward()')

        frame_size = dy.get_frame_size()
        dv = bb.FrameBuffer(frame_size, self.v_shape, dtype=bb.DType.FP32)
        dv.fill_zero()

        # gradient accumulators, one per mipmap level
        mipmap = self.make_mipmap(frame_size, dtype=bb.DType.FP32)
        for _ in range(self.loop):
            du = dv
            for i in reversed(range(self.depth)):
                dx0, dx1, du = self.s_net.backward(mipmap[i], du)
                mipmap[i]   += dx0
                mipmap[i+1] += dx1
            dx0, dx1, du = self.m_net.backward(dy, du)
            mipmap[0] += dx1
        return dx0

In [None]:
# Build the network with default arguments and set its input shape to one
# channel at the size of a full tiled image.
net = MipmapNetwork()
net.set_input_shape([1, img_h*rows, img_w*cols])

In [None]:
# learning
# Train `net` on loader_train for `epochs` passes, showing running loss and
# accuracy in the progress bar.
if True:
    loss      = bb.LossSoftmaxCrossEntropy()
    metrics   = bb.MetricsCategoricalAccuracy()
    optimizer = bb.OptimizerAdam()

    # give the optimizer the network's parameters and gradients
    optimizer.set_variables(net.get_parameters(), net.get_gradients())

    # converters applied before and after the network; both are sized by frame_modulation_size
    real2bin = bb.RealToBinary(frame_modulation_size=frame_modulation_size, bin_dtype=bin_dtype)
    bin2real = bb.BinaryToReal(frame_integration_size=frame_modulation_size, bin_dtype=bin_dtype)

    for epoch in range(epochs):
        # learning
        # reset the running loss / accuracy at the start of each epoch
        loss.clear()
        metrics.clear()
        with tqdm(loader_train) as tqdm_loadr:
            for x_imgs, t_imgs in tqdm_loadr:
                # convert the loader batches to float32 BinaryBrain frame buffers
                x_buf = bb.FrameBuffer.from_numpy(np.array(x_imgs).astype(np.float32))
                t_buf = bb.FrameBuffer.from_numpy(np.array(t_imgs).astype(np.float32))

                # forward pass: real -> binary, network, binary -> real
                x_buf = real2bin.forward(x_buf, train=True)
#               print('fw:', x_buf.get_frame_size(), x_buf.get_node_shape())
                y_buf = net.forward(x_buf, train=True)
#               print('fw:', y_buf.get_frame_size(), y_buf.get_node_shape())
                y_buf = bin2real.forward(y_buf, train=True)

                # loss.calculate returns the gradient buffer used for the backward pass
                dy_buf = loss.calculate(y_buf, t_buf)
                metrics.calculate(y_buf, t_buf)
                
#               print('bw:', dy_buf.get_frame_size(), dy_buf.get_node_shape())
                # backward pass in reverse order, then one optimizer step
                dy_buf = bin2real.backward(dy_buf)
                net.backward(dy_buf)
                optimizer.update()

                tqdm_loadr.set_postfix(loss=loss.get(), acc=metrics.get())

In [None]:
-----------------------------------

In [None]:
# Shape check: a BIT buffer with 48 frames and a 1 x 64 x 64 node shape.
# NOTE(review): 48 presumably equals mini_batch_size * frame_modulation_size - confirm.
x = bb.FrameBuffer(48, (1, 32*2, 32*2), dtype=bb.DType.BIT)
y_shape = net.set_input_shape(x.get_node_shape())
print(x.get_node_shape())
print(y_shape)

In [None]:
# Smoke test: run forward and backward ten times, printing the frame size and
# node shape of each buffer along the way.
for _ in range(10):
    print(x.get_frame_size(), x.get_node_shape())
    y = net.forward(x, train=True)
    print(y.get_frame_size(), y.get_node_shape())

    # gradient buffer with the same frame size and node shape as the output
    # NOTE(review): dy is not explicitly zero-filled here - confirm whether its contents matter
    dy = bb.FrameBuffer(y.get_frame_size(), y.get_node_shape(), dtype=bb.DType.FP32)

    print(dy.get_frame_size(), dy.get_node_shape())
    dx = net.backward(dy)
    print(dx.get_frame_size(), dx.get_node_shape())

In [None]:
-----------------

In [None]:
class MipmapNetwork(bb.Sequential):
    """Looped mipmap pyramid that applies one shared ``ScaledNetwork`` to every level.

    NOTE: this definition replaces any earlier ``MipmapNetwork`` defined in the
    notebook.  Only ``s_net`` takes part in ``forward`` / ``backward``;
    ``m_net`` is constructed and cleared but not otherwise used here.

    Args:
        loop: number of times the pyramid is traversed.
        ch: channel count of the mipmap buffers.
        depth: number of levels processed per traversal.
    """

    def __init__(self, loop=8, ch=32, depth=4):
        self.loop    = loop
        self.depth   = depth
        self.ch      = ch
        self.shape   = None   # node shape of mipmap level 0, set by set_input_shape()
        self.up      = bb.UpSampling((2, 2))
        self.m_net = ScaledNetwork('main', ch, 11)
        self.s_net = ScaledNetwork('sub', ch, ch)
        super(MipmapNetwork, self).__init__([self.m_net, self.s_net])

    def set_input_shape(self, u_shape):
        """Set the shape of input ``u`` (a list ``[c, h, w]``).

        Returns ``u_shape`` with its height and width halved ``depth`` times.
        """
        x0_shape = copy.copy(u_shape)
        x0_shape[0] = self.ch
        x1_shape = copy.copy(u_shape)
        x1_shape[0] = self.ch
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        self.shape = x0_shape
        self.s_net.set_input_shape(x0_shape, x1_shape, u_shape)
        v_shape = copy.copy(u_shape)
        for i in range(self.depth):
            v_shape[1] //= 2
            v_shape[2] //= 2
        return v_shape

    def make_mipmap(self, frame_size, dtype=bb.DType.FP32):
        """Create ``depth + 1`` zero-filled buffers; level 0 has the size stored in ``self.shape``."""
        h = self.shape[1]
        w = self.shape[2]
        mipmap = []
        for i in range(self.depth+1):
            buf = bb.FrameBuffer(frame_size, (self.ch, h, w), dtype=dtype)
            buf.fill_zero()
            mipmap.append(buf)
            h //= 2
            w //= 2
        return mipmap

    def forward(self, u, train=True):
        """Traverse the pyramid ``loop`` times and return the last level's pooled output."""
        self.m_net.clear()
        self.s_net.clear()
        mipmap = self.make_mipmap(u.get_frame_size(), dtype=bin_dtype)
        for _ in range(self.loop):
            u0 = u
            for i in range(self.depth):
                # Forward the caller's train flag; without it the sub-network
                # always ran with the default train=True.
                mipmap[i], u0 = self.s_net.forward(mipmap[i], mipmap[i+1], u0, train=train)
        return u0

    def backward(self, dv):
        """Back-propagate ``dv`` through all traversals and return the gradient for ``u``."""
        mipmap = self.make_mipmap(dv.get_frame_size())
        for _ in range(self.loop):
            dv0 = dv
            for i in reversed(range(self.depth)):
                # the attribute is s_net; the former self.sub_net did not exist
                # and raised AttributeError on the first backward call
                dx0, dx1, dv0 = self.s_net.backward(mipmap[i], dv0)
                mipmap[i]   += dx0
                mipmap[i+1] += dx1
        return dv0

In [None]:
# Instantiate the mipmap network with its default arguments.
mip_net = MipmapNetwork()

In [None]:
# Test input: a BIT buffer with 32 frames and a 32 x 48 x 48 node shape.
u = bb.FrameBuffer(32, (32, 32*3//2, 32*3//2), dtype=bb.DType.BIT)
mip_net.set_input_shape(u.get_node_shape())

In [None]:
#print(mip_net.sub_net.cnv0[0].get_info())
#print(mip_net.sub_net.cnv1[0].get_info())

In [None]:
# Run a forward pass and print the node shape of the result.
v = mip_net.forward(u)
print(v.get_node_shape())

In [None]:
# Run a backward pass with an FP32 gradient buffer shaped like the forward output.
dv = bb.FrameBuffer(v.get_frame_size(), v.get_node_shape(), dtype=bb.DType.FP32)
du = mip_net.backward(dv)

In [None]:
# Print the frame size and node shape of the gradient returned by backward.
print(du.get_frame_size(), du.get_node_shape())

In [None]:
-------------

In [None]:
class SemanticSegmentationIirNetwork(bb.Sequential):
    """Segmentation + classification network.

    NOTE(review): this class looks like an unfinished draft.  The methods
    refer to attributes (m_cnv0, m_cnv1, s_cnv0, s_cnv1, net_cnv) that
    __init__ never assigns and to helpers (buf_concat, buf_split) that are not
    defined in the visible notebook, and backward() ends in a stray fragment
    that is not valid Python.  Confirm the intended design before using it.
    """
    def __init__(self):
        self.ch    = 32
        self.depth = 4

        self.real2bin = bb.RealToBinary(frame_modulation_size=frame_modulation_size, bin_dtype=bin_dtype)
        self.bin2real = bb.BinaryToReal(frame_integration_size=frame_modulation_size, bin_dtype=bin_dtype)
        
        self.pool  = bb.MaxPooling((2, 2))
        self.up    = bb.UpSampling((2, 2))
        
        # for the full-resolution (1x) scale
        self.cnv0 = create_conv_block('m_cnv0', self.ch)
        self.cnv1 = create_conv_block('m_cnv1', self.ch)
        
        # for the mipmap levels
        self.mipmap = MipmapNetwork(loop=8, ch=32, depth=4)
        
        super(SemanticSegmentationIirNetwork, self).__init__([self.real2bin, self.bin2real, self.pool, self.up, self.cnv0, self.cnv1, self.mipmap])
        
    def set_input_shape(self, shape):
        # NOTE(review): self.m_cnv0 and self.net_cnv are not assigned in __init__
        shape = self.real2bin.set_input_shape(shape)
        shape[0] += 1
        self.m_cnv0.set_input_shape(shape)
        
        shape = self.net_cnv.set_input_shape(shape)
        shape = self.bin2real.set_input_shape(shape)
        return shape
    
    def forward(self, x, train=True):
        shape = x.get_node_shape()
        n = x.get_frame_size()
        c = shape[0]
        h = shape[1]
        w = shape[2]
        
        # create the mipmap buffers (each level halves the height and width)
        mipmap = []
        for i in range(self.depth):
            h //= 2
            w //= 2
            buf = bb.FrameBuffer(n, (self.ch, h, w), dtype=bin_dtype)
#           buf.fill_zero()
            mipmap.append(buf)
        
        # loop several times until the filter effect appears
        x = self.real2bin.forward(x, train=train)
        for loop in range(3):
            # concatenate with the largest mipmap level and run inference
            # NOTE(review): buf_concat, self.m_cnv0 and self.m_cnv1 are not defined in the visible source
            y = self.m_cnv0.forward(buf_concat(x, self.up.forward(mipmap[0], train=train)), train=train)
            x0 = self.pool.forward(y, train=train)
            y = self.m_cnv1.forward(y, train=train)
            
            # concatenate adjacent mipmap levels and run each through the same network
            # NOTE(review): self.s_cnv0 and self.s_cnv1 are not assigned in __init__
            for i in range(self.depth):
                if i+1 < self.depth:
                    x1 = self.s_cnv0.forward(buf_concat(mipmap[i], self.up.forward(mipmap[i+1], train=train)), train=train)
                else:
                    x1 = self.s_cnv0.forward(buf_concat(mipmap[i], mipmap[i]), train=train) # lowest level
                mipmap[i] = self.s_cnv1.forward(buf_concat(x1, x0), train=train)
                x0 = self.pool.forward(x1, train=train)
        
        y = self.bin2real.forward(y, train=train)
        return y
    
    def backward(self, dy):
        shape = dy.get_node_shape()
        n = dy.get_frame_size()
        c = shape[0]
        h = shape[1]
        w = shape[2]
        
        # create the mipmap buffers (each level halves the height and width)
        mipmap = []
        for i in range(self.depth):
            h //= 2
            w //= 2
            buf = bb.FrameBuffer(n, (self.ch, h, w), dtype=bin_dtype)
#           buf.fill_zero()
            mipmap.append(buf)
        dy = self.bin2real.backward(dy)
        
        dy0 = bb.FrameBuffer(n, (self.ch, h, w))
#       dy0.fill_zero()
        for loop in reversed(range(3)):
            # NOTE(review): self.depth is an int, so reversed(self.depth) raises
            # TypeError; reversed(range(self.depth)) was presumably intended
            for i in reversed(self.depth):
                dy0 = self.pool.backward(dy0)
                # NOTE(review): buf_split is not defined in the visible source
                dy0, dy1 = buf_split(self.s_cnv1.backward(mipmap[i]), self.ch//2)
                    
                    # NOTE(review): the next line is a leftover fragment and is not valid Python
                    buf_concat(x1, x0), train=train)
                
      
        return dy