In [None]:
import os
import shutil
import random
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
#from tqdm import tqdm
import copy

import torch
import torchvision
import torchvision.transforms as transforms

import binarybrain as bb

In [None]:
# Query the installed BinaryBrain version string (shown as this cell's output).
bb.get_version_string()

In [None]:
# NOTE(review): bin_mode is not read anywhere in the visible notebook.
bin_mode = True
# Number of passes over loader_train in the training loop.
epochs = 8

# Passed to MipmapNetwork as `loop` (number of recurrent passes).
loops = 4
# Passed to MipmapNetwork as `depth` (number of lower-scale levels).
depth = 3
# Channel width setting (same value as MipmapNetwork's default ch=32).
ch = 32

# Number of MNIST tiles laid out per training image
rows=2
cols=2
# Size of a single tile after resizing
img_h = 32
img_w = 32

# Directory where the generated tiled dataset is cached as dataset.pickle
data_path = './data/mnist_iir_segmentation'
# Batch size used by both DataLoaders
mini_batch_size = 16

## 学習データ作成

In [None]:
# dataset
# torchvision's Resize takes the target size as (height, width), so the
# height must come first; the order only goes unnoticed while img_h == img_w.
transform = transforms.Compose([
                transforms.Resize((img_h, img_w)),
                transforms.ToTensor(),
            ])

# MNIST is downloaded into dataset_path on first use.
dataset_path = './data/'
dataset_train = torchvision.datasets.MNIST(root=dataset_path, train=True, transform=transform, download=True)
dataset_test  = torchvision.datasets.MNIST(root=dataset_path, train=False, transform=transform, download=True)


def make_teacher_image(gen, rows, cols, margin=0):
    """Compose one tiled training sample from rows*cols items drawn from `gen`.

    Args:
        gen: iterator yielding (img, label) pairs; each img is written into an
            img_h x img_w window of the canvas.
        rows, cols: tile grid dimensions.
        margin: when > 0, this many pixels are cropped from every border of
            the teaching image only (the source image is returned uncropped).

    Returns:
        (source_img, teaching_img): float32 arrays with 1 and 11 channels.
        Teaching channel `label` holds the thresholded digit mask, channel 10
        holds the thresholded inverse (background); each channel is scaled by
        the module-level `weight` array. The source image is inverted at random.
    """
    canvas_h, canvas_w = rows*img_h, cols*img_w
    source_img = np.zeros((1, canvas_h, canvas_w), dtype=np.float32)
    teaching_img = np.zeros((11, canvas_h, canvas_w), dtype=np.float32)

    for row in range(rows):
        top = row*img_h
        for col in range(cols):
            left = col*img_w
            img, label = next(gen)
            tile = (slice(top, top+img_h), slice(left, left+img_w))
            source_img[(0,) + tile] = img
            teaching_img[(label,) + tile] = img
            teaching_img[(10,) + tile] = (1.0-img)

    # Binarize every channel at 0.5
    teaching_img = (teaching_img > 0.5).astype(np.float32)

    # Apply the area-based class weights
    for cls in range(11):
        teaching_img[cls] *= weight[cls]

    # Invert the source image at random
    if random.random() > 0.5:
        source_img = 1.0 - source_img

    if margin > 0:
        inner = slice(margin, -margin)
        return source_img, teaching_img[:, inner, inner]
    return source_img, teaching_img

def transform_data(dataset, n, rows, cols, margin):
    """Generate `n` tiled samples from `dataset` via make_teacher_image.

    The dataset is read in index order and wraps around to the start once
    exhausted, so `n * rows * cols` may exceed `len(dataset)`.

    Returns:
        (source_imgs, teaching_imgs): two lists of length `n`.
    """
    def endless_samples():
        # Yield dataset items forever, cycling by index.
        size = len(dataset)
        idx = 0
        while True:
            yield dataset[idx % size]
            idx += 1

    gen = endless_samples()
    pairs = [make_teacher_image(gen, rows, cols, margin) for _ in range(n)]
    source_imgs = [src for src, _ in pairs]
    teaching_imgs = [teach for _, teach in pairs]
    return source_imgs, teaching_imgs

class MyDatasets(torch.utils.data.Dataset):
    """Dataset pairing pre-built source images with their teaching images.

    Args:
        source_imgs: indexable collection of input images.
        teaching_imgs: indexable collection of targets, aligned with source_imgs.
        transforms: optional callable taking (source_img, teaching_img) and
            returning the transformed pair.
    """
    def __init__(self, source_imgs, teaching_imgs, transforms=None):
        self.transforms = transforms
        self.source_imgs = source_imgs
        self.teaching_imgs = teaching_imgs

    def __len__(self):
        # Length is defined by the source images.
        return len(self.source_imgs)

    def __getitem__(self, index):
        pair = (self.source_imgs[index], self.teaching_imgs[index])
        if not self.transforms:
            return pair
        source_img, teaching_img = self.transforms(*pair)
        return source_img, teaching_img

# Build the tiled dataset used by the filter network, caching it on disk.
dataset_fname = os.path.join(data_path, 'dataset.pickle')
if os.path.exists(dataset_fname):
    # (Replace the condition with `if False:` to force regeneration.)
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load caches that this notebook wrote itself.
    # NOTE(review): the cache is keyed on the file name only, so it is not
    # regenerated when rows/cols/img_h/img_w change — delete it manually.
    with open(dataset_fname, 'rb') as f:
        # Objects are read back in the same order they are dumped below.
        source_imgs_train = pickle.load(f)
        teaching_imgs_train = pickle.load(f)
        source_imgs_test = pickle.load(f)
        teaching_imgs_test = pickle.load(f)
        weight = pickle.load(f)
else:
    # Build class weights from area ratios: accumulate the mean pixel value
    # per digit class (index 10 accumulates the inverse, i.e. background).
    areas = np.zeros((11))
    for img, label in dataset_train:
        img = img.numpy()
        areas[label] += np.mean(img)
        areas[10] += np.mean(1.0-img)
    areas /= len(dataset_train)
    
    # Inverse-area weighting, normalized so the largest weight is 1.
    weight = 1 / areas
    weight /= np.max(weight)
    
    # `weight` must exist before these calls: make_teacher_image reads it as a global.
    source_imgs_train, teaching_imgs_train = transform_data(dataset_train, 4096, rows, cols, 0) #29)
    source_imgs_test, teaching_imgs_test = transform_data(dataset_test, 128, rows, cols, 0) #, 29)
    
    # Persist everything so later runs can take the fast path above.
    os.makedirs(data_path, exist_ok=True)
    with open(dataset_fname, 'wb') as f:
        pickle.dump(source_imgs_train, f)
        pickle.dump(teaching_imgs_train, f)
        pickle.dump(source_imgs_test, f)
        pickle.dump(teaching_imgs_test, f)
        pickle.dump(weight, f)

my_dataset_train = MyDatasets(source_imgs_train, teaching_imgs_train)
my_dataset_test = MyDatasets(source_imgs_test, teaching_imgs_test)

# Training batches are shuffled; test batches keep a fixed order.
loader_train = torch.utils.data.DataLoader(dataset=my_dataset_train, batch_size=mini_batch_size, shuffle=True)
loader_test = torch.utils.data.DataLoader(dataset=my_dataset_test, batch_size=mini_batch_size, shuffle=False)

In [None]:
def plot_data(x, y, n=2):
    """Plot source images beside their 11 per-class maps, one sample per row.

    Args:
        x: batch of source images; x[i][0] is drawn as the input image.
        y: batch of 11-channel maps; channels 0-9 are titled by class index
            and channel 10 is titled 'background'.
        n: maximum number of samples to show (capped at x.shape[0]).
    """
    n = min(x.shape[0], n)
    plt.figure(figsize=(18,2*n))
    for i in range(n):
        # Column 1: the input image
        plt.subplot(n,12,i*12+1)
        plt.title('source')
        plt.imshow(x[i][0], 'gray')
        # Columns 2-12: one panel per output channel
        for j in range(11):
            plt.subplot(n,12,i*12+2+j)
            plt.title('class=%d'%j if j < 10 else 'background')
            plt.imshow(y[i][j], 'gray')
    plt.tight_layout()
    plt.show()

In [None]:
# Visual check of the generated data (disabled; set the condition to True to run)
if False:
    plt.figure(figsize=(16,8))
    # Grab only the first batch from the test loader.
    for x, t in loader_test:
        break

    plot_data(x, t, 4)

## ネットワーク定義

In [None]:
def dump_object(obj, file):
    """Serialize `obj` with pickle into the path `file`, overwriting it."""
    with open(file, mode='wb') as sink:
        pickle.dump(obj, sink)

def load_object(file):
    """Read and return the first pickled object stored at the path `file`.

    NOTE(review): unpickling can execute arbitrary code; use only on files
    produced by a trusted source.
    """
    with open(file, mode='rb') as source:
        restored = pickle.load(source)
    return restored

In [None]:
def print_numpy_info(x, name=""):
    """Print a one-line summary of array `x`: std, min, max, NaN presence and shape.

    `name` is shown in square brackets at the start of the line.
    """
    summary = '[%s] std:%f, min:%f, max:%f, nan:'%(name, np.std(x), np.min(x), np.max(x))
    has_nan = np.isnan(x).any()
    print(summary, has_nan, x.shape)

In [None]:
# NOTE(review): test_x is not referenced anywhere else in the visible notebook;
# it looks like a leftover debugging hook — confirm before removing.
test_x = None

class AffineBlock(bb.Sequential):
    """DenseAffine layer followed by BatchNormalization.

    Args:
        output_ch: number of output channels; the affine layer is created
            with output shape [output_ch, 1, 1].
        name: name forwarded to the bb.Sequential base class.
    """
    def __init__(self, output_ch, name=''):
        self.affine = bb.DenseAffine([output_ch, 1, 1])
        self.bn     = bb.BatchNormalization()
        layers = [self.affine, self.bn]
        super().__init__(layers, name=name)

    def forward(self, x, train=True):
        # affine first, then batch normalization
        return self.bn.forward(self.affine.forward(x, train), train)

    def backward(self, dy):
        # reverse order of forward: batch normalization, then affine
        return self.affine.backward(self.bn.backward(dy))

def create_dense_conv(output_ch, filter_size=(3, 3), padding='same', last=False, name=''):
    """Create a bb.Convolution2d whose per-window sub-layer is an AffineBlock.

    Args:
        output_ch: output channel count passed to AffineBlock.
        filter_size: convolution window size.
        padding: padding mode forwarded to bb.Convolution2d.
        last: accepted for interface compatibility; not used by this function.
        name: name given to the AffineBlock.
    """
    # Alternatives tried earlier (previously kept here as commented-out code):
    # a two-stage bb.DifferentiableLut Sequential ([output_ch*6,1,1] -> [output_ch,1,1]),
    # and an inline Sequential of bb.DenseAffine + bb.BatchNormalization.
    sub_layer = AffineBlock(output_ch, name=name)
    return bb.Convolution2d(sub_layer, filter_size=filter_size, padding=padding)

class ConvBlock(bb.Sequential):
    """Basic block: conv -> ReLU -> conv, plus a trailing ReLU unless `last` is set.

    Args:
        in_ch: accepted for readability at call sites; not used here.
        out_ch: output channels of both convolutions.
        last: when true, the final ReLU is omitted.
        name: block name; the convolutions are named name+".cnv0" / name+".cnv1".
    """
    def __init__(self, in_ch=32, out_ch=32, last=False, name=""):
        self.cnv0  = create_dense_conv(out_ch, name=name+".cnv0")
        self.relu0 = bb.ReLU()
        self.cnv1  = create_dense_conv(out_ch, name=name+".cnv1")
        self.relu1 = None if last else bb.ReLU()

        stages = [self.cnv0, self.relu0, self.cnv1]
        if self.relu1 is not None:
            stages.append(self.relu1)
        super().__init__(stages, name=name)

    def backward(self, dy):
        grad = dy
        # Undo the optional trailing ReLU first, then the fixed stages in reverse.
        if self.relu1 is not None:
            grad = self.relu1.backward(grad)
        for layer in (self.cnv1, self.relu0, self.cnv0):
            grad = layer.backward(grad)
        return grad

class ScaledNetwork(bb.Sequential):
    """One scale level of the hierarchy.

    forward() upsamples `x1`, concatenates it with `x0`, applies cnv0, and
    returns both the cnv1 output `y` and a max-pooled copy `v` of the cnv0
    output. For non-top levels the extra input `u` is concatenated in front
    of the cnv0 output before cnv1; the top level ignores `u` and its cnv1
    produces 11 channels without a trailing ReLU.
    """
    def __init__(self, ch=32, top=False):
        self.top = top
        self.ch  = ch

        self.up   = bb.UpSampling((2, 2))
        self.pool = bb.MaxPooling((2, 2))
        self.cat0 = bb.Concatenate()
        self.cat1 = bb.Concatenate()
        if self.top:
            # top level: image channel + ch feature channels in, 11 classes out
            self.cnv0 = ConvBlock(in_ch=self.ch+1, out_ch=self.ch, name="m_cnv0")
            self.cnv1 = ConvBlock(in_ch=self.ch, out_ch=11, last=self.top, name="m_cnv1")
        else:
            self.cnv0 = ConvBlock(in_ch=self.ch*2, out_ch=self.ch, name="s_cnv0")
            self.cnv1 = ConvBlock(in_ch=self.ch*2, out_ch=self.ch, name="s_cnv1")

        # The Concatenate layers are not registered with the Sequential base.
        super().__init__([self.up, self.pool, self.cnv0, self.cnv1])

    def set_input_shape(self, x0_shape, x1_shape, u_shape):
        """Propagate shapes through the level; returns (y_shape, v_shape)."""
        up_shape = self.up.set_input_shape(x1_shape)
        feat_shape = self.cat0.set_input_shape([x0_shape, up_shape])
        feat_shape = self.cnv0.set_input_shape(feat_shape)
        v_shape = self.pool.set_input_shape(feat_shape)
        if not self.top:
            feat_shape = self.cat1.set_input_shape([u_shape, feat_shape])
        return self.cnv1.set_input_shape(feat_shape), v_shape

    def forward(self, x0, x1, u, train=True):
        """Run the level; returns (y, v) as described on the class."""
        upsampled = self.up.forward(x1, train=train)
        feat = self.cat0.forward([x0, upsampled])
        feat = self.cnv0.forward(feat, train=train)
        v = self.pool.forward(feat, train=train)
        if not self.top:
            feat = self.cat1.forward([u, feat])
        y = self.cnv1.forward(feat, train=train)
        return y, v

    def backward(self, dy, dv):
        """Back-propagate gradients of (y, v); returns (dx0, dx1, du).

        `du` is None for the top level, which has no `u` input.
        """
        grad = self.cnv1.backward(dy)
        if self.top:
            du, grad_main = None, grad
        else:
            du, grad_main = self.cat1.backward([grad])
        grad_pool = self.pool.backward(dv)
        # cnv0's output fed both cnv1 (via cat1) and the pooling branch.
        grad = self.cnv0.backward(grad_main + grad_pool)
        dx0, dx1 = self.cat0.backward([grad])
        dx1 = self.up.backward(dx1)
        return dx0, dx1, du


class MipmapNetwork(bb.Sequential):
    """Mipmap-style recurrent network.

    A top-level ScaledNetwork (`m_net`) and a lower-scale ScaledNetwork
    (`s_net`) are applied repeatedly. Each of the `loop` passes runs `m_net`
    on the input and the first mipmap level; every pass except the last then
    runs `s_net` over `depth` mipmap levels to update them for the next pass.

    Args:
        loop: number of passes; forward() returns one output per pass.
        depth: number of mipmap levels updated by `s_net` per pass.
        ch: channel count of the mipmap buffers and of both sub-networks.
    """
    def __init__(self, loop=3, depth=4, ch=32):
        self.loop    = loop
        self.depth   = depth
        self.ch      = ch
        self.shape   = None
        # Shape of the last `v` produced by s_net in forward(); stays None
        # when loop == 1 because s_net is never run in that case.
        self.v_shape = None
        self.up      = bb.UpSampling((2, 2))
        self.m_net = ScaledNetwork(ch, top=True)
        self.s_net = ScaledNetwork(ch)
        super(MipmapNetwork, self).__init__([self.m_net, self.s_net])
    
    def set_input_shape(self, shape):
        """Configure both sub-networks for input `shape`; returns m_net's output shape."""
        self.shape = copy.copy(shape)
        
        # Top level: full-resolution input plus a half-resolution, ch-channel mipmap.
        x0_shape = copy.copy(shape)
        x1_shape = copy.copy(shape)
        x1_shape[0] = self.ch
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        u_shape = copy.copy(shape)
        y_shape, v_shape = self.m_net.set_input_shape(x0_shape, x1_shape, u_shape)
        
        # Lower level: the same arrangement one scale down.
        x0_shape = copy.copy(x1_shape)
        x0_shape[0] = self.ch
        x1_shape = copy.copy(x1_shape)
        x1_shape[0] = self.ch
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        self.s_net.set_input_shape(x0_shape, x1_shape, v_shape)
        
        return y_shape
    
    def make_mipmap(self, frame_size, dtype=bb.DType.FP32):
        """Return depth+1 zero-filled buffers, each half the size of the previous.

        The first buffer is half the resolution of `self.shape`.
        """
        h = self.shape[1]
        w = self.shape[2]
        mipmap = []
        for _ in range(self.depth+1):
            h //= 2
            w //= 2
            buf = bb.FrameBuffer(frame_size, (self.ch, h, w), dtype=dtype)
            buf.fill_zero()
            mipmap.append(buf)
        return mipmap
    
    def forward(self, x, train=True):
        """Run `loop` passes and return the list of per-pass m_net outputs."""
        self.m_net.clear()
        self.s_net.clear()
        
        frame_size = x.get_frame_size()
        self.shape = x.get_node_shape()
        
        y_list = []
        mipmap = self.make_mipmap(frame_size)
        for i in range(self.loop):
            y, v = self.m_net.forward(x, mipmap[0], None, train=train)
            if i < self.loop-1:
                for j in range(self.depth):
                    # Propagate `train`: ScaledNetwork.forward defaults to
                    # train=True, so omitting it would run s_net in training
                    # mode even for inference calls such as view().
                    mipmap[j], v = self.s_net.forward(mipmap[j], mipmap[j+1], v, train=train)
                self.v_shape = v.get_node_shape()
            y_list.append(y)
        return y_list
    
    def backward(self, dy_list):
        """Back-propagate the per-pass output gradients; returns the input gradient.

        `dy_list` must hold one gradient per pass, in forward order.
        """
        frame_size = dy_list[0].get_frame_size()
        
        # Zero gradient for the `v` output of the deepest s_net call, which
        # nothing consumes. Only needed (and only has a known shape) when
        # s_net actually ran, i.e. loop > 1.
        dv = None
        if self.loop > 1:
            dv = bb.FrameBuffer(frame_size, self.v_shape, dtype=bb.DType.FP32)
            dv.fill_zero()
        
        # Zero gradient for m_net's `v` output on the final pass, where no
        # s_net consumed it: ch channels at half the output resolution.
        shape = dy_list[0].get_node_shape()
        shape[0] = self.ch
        shape[1] //=2
        shape[2] //=2
        du = bb.FrameBuffer(frame_size, shape, dtype=bb.DType.FP32)
        du.fill_zero()
        
        # mipmap[j] holds the running gradient w.r.t. mipmap level j.
        mipmap = self.make_mipmap(frame_size, dtype=bb.DType.FP32)
        for i in reversed(range(self.loop)):
            if i < self.loop-1:
                du = dv
                for j in reversed(range(self.depth)):
                    dx0, dx1, du = self.s_net.backward(mipmap[j], du)
                    mipmap[j]    = dx0
                    mipmap[j+1] += dx1
            dx0, dx1, du = self.m_net.backward(dy_list[i], du)
            # NOTE(review): on passes where s_net ran, mipmap[0] already holds
            # s_net's dx0 for the same forward-time buffer; plain assignment
            # discards it. `mipmap[0] += dx1` may be intended — confirm.
            mipmap[0] = dx1
        return dx0

def view(net, loader, n=2):
    """Run `net` in inference mode on the first batch of `loader` and plot it.

    Only the output of the final pass (last element of the list returned by
    net.forward) is shown; at most `n` samples are plotted.
    """
    # Take just the first batch; the targets are not needed here.
    for batch_x, _ in loader:
        break

    frames = bb.FrameBuffer.from_numpy(np.array(batch_x).astype(np.float32))
    outputs = net.forward(frames, train=False)

    plot_data(frames.numpy(), outputs[-1].numpy(), n)

## 学習

In [None]:
# Pass the channel width explicitly so the `ch` setting actually takes effect
# instead of silently relying on MipmapNetwork's constructor default.
net = MipmapNetwork(loop=loops, depth=depth, ch=ch)
# Input is a single-channel image made of rows x cols tiles.
net.set_input_shape([1, img_h*rows, img_w*cols])

In [None]:
# Loss, metric and optimizer for training.
loss      = bb.LossSoftmaxCrossEntropy()
metrics   = bb.MetricsCategoricalAccuracy()
optimizer = bb.OptimizerAdam(learning_rate=0.00001)

# Register the network's parameters and gradients with the optimizer.
optimizer.set_variables(net.get_parameters(), net.get_gradients())

for epoch in range(epochs):
    # learning
    # Reset the running loss/accuracy (and the network's state) each epoch.
    loss.clear()
    metrics.clear()
    net.clear()
    with tqdm(loader_train) as tqdm_loadr:
        for x, t in tqdm_loadr:

            # Convert the torch batches to BinaryBrain frame buffers.
            x = bb.FrameBuffer.from_numpy(np.array(x).astype(np.float32))
            t = bb.FrameBuffer.from_numpy(np.array(t).astype(np.float32))

            # The network returns one output per recurrent pass.
            y_list = net.forward(x, train=True)
            
            # Compute the loss gradient for each of the outputs
            dy_list = []
            for y in y_list:
                dy = loss.calculate(y, t)
                dy_list.append(dy)
            
            # Measure accuracy on the last output only
            metrics.calculate(y_list[-1], t)
            
            # backward
            net.backward(dy_list)
            optimizer.update()
            
            # Show the running loss and accuracy on the progress bar.
            tqdm_loadr.set_postfix(loss=loss.get(), acc=metrics.get())
# Optional per-epoch preview (disabled): view(net, loader_test, n=2)

In [None]:
# Plot predictions for the first test batch (at most 8 samples).
view(net, loader_test, n=8)