In [None]:
import os
import shutil
import random
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
#from tqdm import tqdm
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Function

import binarybrain as bb

In [None]:
# Show the installed BinaryBrain version (rendered as this cell's output).
bb.get_version_string()

In [None]:
# --- configuration ---
# NOTE(review): use_bb / use_torch / verbose are not referenced by any active line of this notebook.
use_bb = True
use_torch = False

# NOTE(review): bin_mode is only read by a commented-out line; binary mode is actually
# switched on by net.send_command("binary true") in the training section.
bin_mode = True
frame_modulation_size = 1  # passed to RealToBinary / BinaryToReal inside MipmapNetwork

verbose = 0

# NOTE(review): the training cell reassigns epochs=64 right before its loop, overriding this value.
epochs = 8

loops = 5  # number of recurrent passes of MipmapNetwork
depth = 3  # number of s_net pyramid levels refreshed per pass
# NOTE(review): ch is not passed to MipmapNetwork(...) below, which therefore uses its default ch=32 — confirm intent.
ch = 64

# Number of tiles arranged per image (rows x cols MNIST digits)
rows=2
cols=2
img_h = 32
img_w = 32

data_path = './data/cmp_torch_mnist_iir_segmentation_bb2'
mini_batch_size = 16

## 学習データ作成

In [None]:
# dataset
# transforms.Resize expects its size as (height, width).
transform = transforms.Compose([
                transforms.Resize((img_h, img_w)),
                transforms.ToTensor(),
            ])

dataset_path = './data/'
dataset_train = torchvision.datasets.MNIST(root=dataset_path, train=True, transform=transform, download=True)
dataset_test  = torchvision.datasets.MNIST(root=dataset_path, train=False, transform=transform, download=True)

def make_teacher_image(gen, rows, cols, margin=0, weights=None):
    """Tile rows x cols digits drawn from `gen` into one source image and its target.

    Args:
        gen: iterator yielding ``(img, label)`` pairs; each ``img`` must be assignable
            into an ``img_h x img_w`` slice (e.g. an ``img_h x img_w`` array or a
            ``1 x img_h x img_w`` tensor).
        rows, cols: number of tiles vertically / horizontally.
        margin: if > 0, crop this many pixels from every edge of the target only.
        weights: per-channel weights (length 11) multiplied into the target. When None,
            the module-level ``weight`` (computed in the dataset cell) is used.

    Returns:
        (source_img, teaching_img): ``source_img`` has shape
        ``(1, rows*img_h, cols*img_w)`` and is inverted (1 - x) with probability 0.5;
        ``teaching_img`` has 11 channels (digit classes 0-9, background 10) whose
        entries are either 0 or that channel's weight.
    """
    if weights is None:
        weights = weight  # module-level class-area weights

    height, width = rows * img_h, cols * img_w
    source_img = np.zeros((1, height, width), dtype=np.float32)
    teaching_img = np.zeros((11, height, width), dtype=np.float32)
    for row in range(rows):
        for col in range(cols):
            x = col * img_w
            y = row * img_h
            img, label = next(gen)
            source_img[0, y:y+img_h, x:x+img_w] = img
            teaching_img[label, y:y+img_h, x:x+img_w] = img
            # channel 10 is the background class
            teaching_img[10, y:y+img_h, x:x+img_w] = (1.0 - img)
    teaching_img = (teaching_img > 0.5).astype(np.float32)

    # Weight each channel by its (area-based) class weight
    teaching_img *= np.asarray(weights, dtype=np.float32)[:11].reshape(-1, 1, 1)

    # Randomly invert the source image (target unchanged) as augmentation
    if random.random() > 0.5:
        source_img = 1.0 - source_img

    if margin > 0:
        return source_img, teaching_img[:, margin:-margin, margin:-margin]
    return source_img, teaching_img

def transform_data(dataset, n, rows, cols, margin):
    """Build `n` tiled (source, teaching) image pairs, cycling through `dataset` by index.

    Returns two parallel lists: the source images and the teaching images.
    """
    def data_gen():
        # Endless round-robin over the dataset (indexing applies the dataset's transform)
        size = len(dataset)
        idx = 0
        while True:
            yield dataset[idx % size]
            idx += 1

    samples = data_gen()
    pairs = [make_teacher_image(samples, rows, cols, margin) for _ in range(n)]
    source_imgs = [src for src, _ in pairs]
    teaching_imgs = [tgt for _, tgt in pairs]
    return source_imgs, teaching_imgs

class MyDatasets(torch.utils.data.Dataset):
    """Map-style dataset over parallel lists of source and teaching images.

    `transforms`, when truthy, is called as ``transforms(source, teaching)`` and must
    return a (source, teaching) pair.
    """

    def __init__(self, source_imgs, teaching_imgs, transforms=None):
        self.source_imgs = source_imgs
        self.teaching_imgs = teaching_imgs
        self.transforms = transforms

    def __len__(self):
        return len(self.source_imgs)

    def __getitem__(self, index):
        src = self.source_imgs[index]
        tgt = self.teaching_imgs[index]
        if not self.transforms:
            return src, tgt
        src, tgt = self.transforms(src, tgt)
        return src, tgt

# Tile the digits into multi-digit images for the filter (segmentation) task
dataset_fname = os.path.join(data_path, 'dataset.pickle')
# Set to True to reuse the dataset/weights pickled by a previous run instead of regenerating.
# NOTE: pickle.load can execute arbitrary code; only enable this for a cache file you created.
use_dataset_cache = False
if use_dataset_cache and os.path.exists(dataset_fname):
    with open(dataset_fname, 'rb') as f:
        source_imgs_train = pickle.load(f)
        teaching_imgs_train = pickle.load(f)
        source_imgs_test = pickle.load(f)
        teaching_imgs_test = pickle.load(f)
        weight = pickle.load(f)
else:
    # Class weights from the inverse of each class's mean area
    # (channels 0-9: digit pixels per label, channel 10: background)
    areas = np.zeros((11))
    for img, label in dataset_train:
        img = img.numpy()
        areas[label] += np.mean(img)
        areas[10] += np.mean(1.0-img)
    areas /= len(dataset_train)

    weight = 1 / areas
    weight[10] *= 10  # boost the background weight before normalizing
    weight /= np.sum(weight)

    source_imgs_train, teaching_imgs_train = transform_data(dataset_train, 4096, rows, cols, 0)  # margin=0 (alt: 29)
    source_imgs_test, teaching_imgs_test = transform_data(dataset_test, 128, rows, cols, 0)  # margin=0 (alt: 29)

    # Cache the generated data; read back above when use_dataset_cache is True
    os.makedirs(data_path, exist_ok=True)
    with open(dataset_fname, 'wb') as f:
        pickle.dump(source_imgs_train, f)
        pickle.dump(teaching_imgs_train, f)
        pickle.dump(source_imgs_test, f)
        pickle.dump(teaching_imgs_test, f)
        pickle.dump(weight, f)

my_dataset_train = MyDatasets(source_imgs_train, teaching_imgs_train)
my_dataset_test = MyDatasets(source_imgs_test, teaching_imgs_test)

loader_train = torch.utils.data.DataLoader(dataset=my_dataset_train, batch_size=mini_batch_size, shuffle=True)
loader_test = torch.utils.data.DataLoader(dataset=my_dataset_test, batch_size=mini_batch_size, shuffle=False)

In [None]:
def plot_data(x, y, n=2):
    """Plot up to `n` samples: the source image followed by its 11 channel maps.

    Args:
        x: array indexed as ``x[i][0]`` -> 2-D source image of sample i (needs ``.shape``).
        y: array indexed as ``y[i][j]`` -> 2-D map of channel j
            (0-9: digit classes, 10: background).
        n: maximum number of samples (subplot rows) to draw.
    """
    n = min(x.shape[0], n)
    if n <= 0:
        return  # nothing to draw; avoid creating a zero-height figure
    plt.figure(figsize=(18, 2*n))
    for i in range(n):
        base = i * 12
        plt.subplot(n, 12, base + 1)
        plt.title('source')
        plt.imshow(x[i][0], 'gray')
        for j in range(11):
            plt.subplot(n, 12, base + 2 + j)
            plt.title('class=%d' % j if j < 10 else 'background')
            plt.imshow(y[i][j], 'gray')
    plt.tight_layout()
    plt.show()

In [None]:
# Visual sanity check of the generated data (first test mini-batch); set to True to display.
# plot_data creates its own figure, so no plt.figure() call is needed here.
show_training_samples = False
if show_training_samples:
    x, t = next(iter(loader_test))
    plot_data(x, t, 4)

In [None]:
def to_numpy(x):
    """Convert `x` to a numpy array.

    - numpy arrays (including subclasses such as masked arrays) are returned as-is;
    - torch tensors are moved to CPU, detached and cloned, so the result does not
      share memory with the tensor;
    - BinaryBrain FrameBuffer / Tensor objects use their ``.numpy()`` method;
    - other objects with a ``.to`` method go through the same path as torch tensors;
    - anything else (lists, scalars, ...) falls back to ``np.asarray``.
    """
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, torch.Tensor):
        return x.to('cpu').detach().clone().numpy()
    if isinstance(x, (bb.FrameBuffer, bb.Tensor)):
        return x.numpy()
    if hasattr(x, 'to'):
        return x.to('cpu').detach().clone().numpy()
    return np.asarray(x)

def calc_corrcoef(a, b):
    """Pearson correlation coefficient between the flattened values of `a` and `b`."""
    flat_a = to_numpy(a).ravel()
    flat_b = to_numpy(b).ravel()
    matrix = np.corrcoef(flat_a, flat_b)
    return matrix[0, 1]

def print_summary(x, text=""):
    """Print mean / std / min / max and a NaN flag (0 or 1) of `x`, labelled with `text`."""
    arr = to_numpy(x)
    stats = (text, np.mean(arr), np.std(arr), np.min(arr), np.max(arr), np.isnan(arr).any())
    print("[%s] mean:%1.8f std:%1.8f min:%1.8f max:%1.8f isnan:%d" % stats)

def print_diff(a, b, text=""):
    """Print summary statistics of ``a - b``; `b` is reshaped to the shape of `a` first."""
    lhs = to_numpy(a)
    rhs = np.reshape(to_numpy(b), lhs.shape)
    print_summary(lhs - rhs, text=text)

In [None]:
def print_numpy_info(x, name="W_bb"):
    """Print std / min / max, whether any NaN is present, and the shape of array `x`."""
    header = '[%s] std:%f, min:%f, max:%f, nan:' % (name, np.std(x), np.min(x), np.max(x))
    has_nan = np.isnan(x).any()
    print(header, has_nan, x.shape)

In [None]:
def dump_object(obj, file):
    """Pickle `obj` into the file at path `file`, overwriting it if it exists."""
    with open(file, mode='wb') as stream:
        pickle.dump(obj, stream)

def load_object(file):
    """Load and return the first pickled object stored in the file at path `file`.

    NOTE: unpickling can execute arbitrary code — only load files you trust.
    """
    with open(file, mode='rb') as stream:
        obj = pickle.load(stream)
    return obj

## ネットワーク定義

In [None]:
# In binary mode the BIT dtype could be used to reduce memory usage.
# Currently fixed to FP32; this value is used as fw_dtype / bin_dtype by the layers below.
#bin_dtype = bb.DType.BIT if bin_mode else bb.DType.FP32
bin_dtype = bb.DType.FP32

class Convolution(bb.Sequential):
    """3x3 convolution with 'same' padding whose per-position kernel is a BinaryDenseAffine.

    `in_ch` is accepted for signature symmetry but not used here (presumably the input
    channel count comes from set_input_shape — confirm).
    """
    def __init__(self, in_ch, out_ch, batch_norm=False, activation=True, name=""):
        self.batch_norm = batch_norm
        self.activation = activation

        # sub-network applied at every spatial position, producing out_ch outputs
        self.blk = bb.BinaryDenseAffine(
            [out_ch, 1, 1],
            name=name + '_blk',
            batch_norm=batch_norm,
            activation=activation,
            bin_dtype=bin_dtype,
        )
        self.cnv = bb.Convolution2d(
            self.blk,
            filter_size=(3, 3),
            padding='same',
            name=name + '_cnv',
            fw_dtype=bin_dtype,
        )
        super().__init__([self.cnv], name=name)

class ConvBlock(bb.Sequential):
    """Basic block: two stacked Convolution layers (in_ch -> hid_ch -> out_ch)."""
    def __init__(self, in_ch=32, hid_ch=32, out_ch=32, name=""):
        self.cnv0 = Convolution(in_ch, hid_ch, name=name + "_cnv0")
        self.cnv1 = Convolution(hid_ch, out_ch, name=name + "_cnv1")
        super().__init__([self.cnv0, self.cnv1], name=name)

    def parameters(self):
        return self.cnv0.parameters() + self.cnv1.parameters()

    def set_input_shape(self, shape):
        # propagate the shape through both layers in forward order
        for layer in (self.cnv0, self.cnv1):
            shape = layer.set_input_shape(shape)
        return shape

    def forward(self, x, train=True):
        for layer in (self.cnv0, self.cnv1):
            x = layer.forward(x, train=train)
        return x

    def backward(self, dy):
        # gradients flow through the layers in reverse order
        for layer in (self.cnv1, self.cnv0):
            dy = layer.backward(dy)
        return dy

class ScaledNetwork(bb.Sequential):
    """One level of the scale hierarchy.

    Concatenates `x0` with an up-sampled coarser map `x1`, runs ConvBlock `cnv0`,
    max-pools that result into `v`, and produces `y` with ConvBlock `cnv1`. Non-top
    levels additionally concatenate `u` in front of the cnv1 input (in MipmapNetwork
    `u` is the pooled output `v` of the next finer level).
    """
    def __init__(self, ch=32, top=False):
        # top=True: cnv0 takes 1+ch channels (1-channel x0) and cnv1 emits 11 channels
        self.top = top
        self.ch  = ch

        self.up   = bb.UpSampling((2, 2), fw_dtype=bin_dtype)
        self.pool = bb.MaxPooling((2, 2), fw_dtype=bin_dtype)
        self.cat0 = bb.Concatenate()
        self.cat1 = bb.Concatenate()
        if self.top:
            self.cnv0 = ConvBlock(1+self.ch, self.ch, self.ch, name="m_blk0")
            self.cnv1 = ConvBlock(self.ch,   self.ch, 11,      name="m_blk1")
        else:
            self.cnv0 = ConvBlock(self.ch*2, self.ch, self.ch, name="s_blk0")
            self.cnv1 = ConvBlock(self.ch*2, self.ch, self.ch, name="s_blk1")

        # NOTE(review): cat0/cat1 are not registered here; presumably fine since they hold
        # no parameters — confirm they need no clear()/send_command() propagation.
        super(ScaledNetwork, self).__init__([self.up, self.pool, self.cnv0, self.cnv1])

    def parameters(self):
        return self.cnv0.parameters() + self.cnv1.parameters()

    def set_input_shape(self, x0_shape, x1_shape, u_shape):
        """Propagate shapes; returns (y_shape, v_shape). `u_shape` is ignored at the top level."""
        x1_shape = self.up.set_input_shape(x1_shape)
        x_shape = self.cat0.set_input_shape([x0_shape, x1_shape])
        x_shape = self.cnv0.set_input_shape(x_shape)
        v_shape = self.pool.set_input_shape(x_shape)
        if not self.top:
            x_shape = self.cat1.set_input_shape([u_shape, x_shape])
        y_shape = self.cnv1.set_input_shape(x_shape)
        return y_shape, v_shape

    def forward(self, x0, x1, u, train=True):
        """Returns (y, v): the level output and the pooled cnv0 features."""
        x1 = self.up.forward(x1, train=train)
        x  = self.cat0.forward([x0, x1], train=train)
        x  = self.cnv0.forward(x, train=train)
        v = self.pool.forward(x, train=train)
        if not self.top:
            x = self.cat1.forward([u, x], train=train)
        y = self.cnv1.forward(x, train=train)
        return y, v

    def backward(self, dy, dv):
        """Back-propagate gradients of y (`dy`) and v (`dv`).

        Returns (dx0, dx1, du); du is None at the top level.
        """
        dy = self.cnv1.backward(dy)
        if not self.top:
            # cat1 concatenated [u, x], so its gradient splits in that order
            du, dy0 = self.cat1.backward([dy])
        else:
            du, dy0 = None, dy
        dy1 = self.pool.backward(dv)
        # cnv0's output feeds both the cnv1 path and the pool, so the gradients are summed
        dy = self.cnv0.backward(dy0 + dy1)
        dx0, dx1 = self.cat0.backward([dy])
        dx1 = self.up.backward(dx1)
        return dx0, dx1, du


class MipmapNetwork(bb.Sequential):
    """Mipmap-style recurrent network.

    Keeps a pyramid ("mipmap") of `depth+1` feature maps with `ch` channels at 1/2, 1/4, ...
    of the input resolution. Each of the `loop` passes runs the top network `m_net` on the
    input and the finest pyramid level; every pass except the last then refreshes the
    pyramid level by level with the shared network `s_net`. `forward` returns the output
    of every pass.
    """
    def __init__(self, loop=3, depth=4, ch=32):
        self.loop    = loop
        self.depth   = depth
        self.ch      = ch
        self.shape   = None
        self.u_shape = None  # shape of m_net's pooled output (set by forward)
        self.v_shape = None  # shape of the last s_net pooled output (set by forward when loop > 1)
        self.r2b     = bb.RealToBinary(frame_modulation_size=frame_modulation_size, bin_dtype=bin_dtype)
        self.b2r     = bb.BinaryToReal(frame_integration_size=frame_modulation_size, bin_dtype=bin_dtype)
        self.up      = bb.UpSampling((2, 2), fw_dtype=bin_dtype) # UpSampling()
        self.m_net = ScaledNetwork(ch, top=True)
        self.s_net = ScaledNetwork(ch)
        super(MipmapNetwork, self).__init__([self.m_net, self.s_net])

    def parameters(self):
        return self.m_net.parameters() + self.s_net.parameters()

    def set_input_shape(self, shape):
        self.shape = copy.copy(shape)
        shape = self.r2b.set_input_shape(shape)

        # top level: x0 = input, x1 = finest mipmap level (half resolution)
        x0_shape = copy.copy(shape)
        x1_shape = copy.copy(shape)
        x1_shape[0] = self.ch
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        u_shape = copy.copy(shape)
        y_shape, v_shape = self.m_net.set_input_shape(x0_shape, x1_shape, u_shape)

        # shared scaled level: x0 = level k, x1 = level k+1, u = pooled output of the level above
        x0_shape = copy.copy(x1_shape)
        x0_shape[0] = self.ch
        x1_shape = copy.copy(x1_shape)
        x1_shape[0] = self.ch
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        self.s_net.set_input_shape(x0_shape, x1_shape, v_shape)

        y_shape = self.b2r.set_input_shape(y_shape)
        return y_shape

    def make_mipmap(self, n, dtype=bin_dtype):
        """Allocate `depth+1` zero buffers of `n` frames; level k is (ch, h >> (k+1), w >> (k+1))."""
        h = self.shape[1]
        w = self.shape[2]
        mipmap = []
        for _ in range(self.depth+1):
            h //= 2
            w //= 2
            mipmap.append(bb.FrameBuffer.zeros(n, (self.ch, h, w), dtype=dtype))
        return mipmap

    def forward(self, x, train=True):
        """Run all `loop` passes on `x`; returns the list of per-pass outputs."""
        self.m_net.clear()
        self.s_net.clear()

#       x = self.r2b.forward(x)

        frame_size = x.get_frame_size()
        self.shape = x.get_node_shape()

        y_list = []
        mipmap = self.make_mipmap(frame_size, dtype=bin_dtype)
        for i in range(self.loop):
            y, v = self.m_net.forward(x, mipmap[0], None, train=train)
            self.u_shape = v.get_node_shape()
            if i < self.loop-1:
                # refresh the pyramid from fine to coarse with the shared s_net
                for j in range(self.depth):
                    mipmap[j], v = self.s_net.forward(mipmap[j], mipmap[j+1], v, train=train)
                self.v_shape = v.get_node_shape()
#           y = self.b2r.forward(y)
            y_list.append(y)
        return y_list

    def backward(self, dy_list):
        """Back-propagate the per-pass output gradients `dy_list` through all passes.

        Returns the gradient w.r.t. the input `x`, summed over every pass (x feeds m_net
        on each pass).
        """
        frame_size = dy_list[0].get_frame_size() # * frame_modulation_size

        # The final v of a refresh sweep is never consumed, so its gradient is zero.
        dv = bb.FrameBuffer.zeros(frame_size, self.v_shape, dtype=bb.DType.FP32) if self.loop > 1 else None
        # m_net's pooled output is unused on the last pass, so its gradient is zero there.
        du = bb.FrameBuffer.zeros(frame_size, self.u_shape, dtype=bb.DType.FP32)

        # mipmap[k]: gradient w.r.t. the level-k map produced by the pass being processed
        mipmap = self.make_mipmap(frame_size, dtype=bb.DType.FP32)
        dx = None
        for i in reversed(range(self.loop)):
            if i < self.loop-1:
                du = dv
                for j in reversed(range(self.depth)):
                    dx0, dx1, du = self.s_net.backward(mipmap[j], du)
                    mipmap[j]    = dx0
                    mipmap[j+1] += dx1
            # dy = self.b2r.backward(dy_list[i])
            dx0, dx1, du = self.m_net.backward(dy_list[i], du)
            # The previous pass's level-0 map feeds both s_net level 0 (gradient already in
            # mipmap[0]) and m_net's x1 input, so the two gradients must be summed.
            mipmap[0] = mipmap[0] + dx1
            dx = dx0 if dx is None else dx + dx0
        return dx

def view(net, loader, n=2):
    """Plot the final-pass prediction of `net` for the first mini-batch of `loader`.

    The input is binarized to +1 / -1 (threshold 0.5) exactly as in the training loop,
    so the network sees the same input encoding it was trained on.
    """
    x_torch, _ = next(iter(loader))
    x_np = np.where(np.array(x_torch) > 0.5, 1.0, -1.0).astype(np.float32)

    x = bb.FrameBuffer.from_numpy(x_np).astype(bin_dtype)
    y_list = net.forward(x, train=False)
    y = y_list[-1].numpy()

    plot_data(x_np, y, n)

## 学習

In [None]:
# Build the network for one rows x cols tiled image (1 input channel).
# NOTE(review): `ch` from the config cell is not passed, so MipmapNetwork uses its default ch=32 — confirm intent
# (changing it would alter the saved-network layout).
net = MipmapNetwork(loop=loops, depth=depth)
net.set_input_shape([1, img_h*rows, img_w*cols])
net.send_command("binary true")

# Restore previously saved weights from data_path (presumably a no-op when nothing was saved yet — confirm).
bb.load_networks(data_path, net)

In [None]:
criterion = bb.LossSoftmaxCrossEntropy()
metrics   = bb.MetricsCategoricalAccuracy()
optimizer = bb.OptimizerAdam(learning_rate=0.0001)
optimizer.set_variables(net.get_parameters(), net.get_gradients())

criterion.clear()
metrics.clear()
net.clear()

# criterion.get() after every mini-batch (the criterion is cleared only at the start of each epoch)
log_loss = []

# NOTE(review): this overrides `epochs = 8` from the config cell.
epochs=64
for epoch in range(epochs):
    # learning
    criterion.clear()
    metrics.clear()

    with tqdm(loader_train) as tqdm_loadr:
        for x_torch, t_torch in tqdm_loadr:
            net.clear()

            # Binarize the input to +1 / -1
            x_torch[x_torch >  0.5] = +1
            x_torch[x_torch <= 0.5] = -1

            x = bb.FrameBuffer.from_numpy(np.array(x_torch).astype(np.float32)).astype(bin_dtype)
            t = bb.FrameBuffer.from_numpy(np.array(t_torch).astype(np.float32))

            # one output per recurrent pass
            y_list = net.forward(x, train=True)

            # Only the last two passes receive a loss gradient; earlier passes get zeros.
            dy_list = []
            for i, y in enumerate(y_list):
                if i < loops-2:
                    dy = bb.FrameBuffer.zeros(t.get_frame_size(), t.get_node_shape())
                else:
                    dy = criterion.calculate(y, t)
                dy_list.append(dy)
            metrics.calculate(y_list[-1], t)  # accuracy of the final pass only

            net.backward(dy_list)
            optimizer.update()


            loss_val = criterion.get()
#           criterion.clear()

            log_loss.append(loss_val)

            tqdm_loadr.set_postfix(loss=loss_val, accuracy=metrics.get())

    bb.save_networks(data_path, net, backups=3)
    # Show predictions on a test mini-batch every 10 epochs
    if epoch % 10 == 9:
        view(net, loader_test, n=4)

In [None]:
# Predict one training mini-batch (binarized exactly as in training) and plot the results.
x_torch, t_torch = next(iter(loader_train))
x      = np.where(np.array(x_torch) > 0.5, 1.0, -1.0).astype(np.float32)
x_buf  = bb.FrameBuffer.from_numpy(x).astype(bin_dtype)
y_list = net.forward(x_buf, train=False)
y      = y_list[-1].numpy()

plot_data(x, y, 8)

In [None]:
# Plot final-pass predictions for the first test mini-batch (loader_test is not shuffled).
view(net, loader_test, n=4)

In [None]:
# NOTE(review): this line is not valid Python and raises SyntaxError — presumably a deliberate
# stop so that "Run All" halts before the save cells below (or a stray markdown rule); confirm.
-------

In [None]:
# Save a snapshot of the network under the name "hoge" (placeholder name — NOTE(review): rename?).
bb.save_networks(data_path, net, name="hoge")

In [None]:
# Save the network as "layers" with write_layers=True (presumably also writes per-layer data — confirm).
bb.save_networks(data_path, net, name="layers", write_layers=True)