In [None]:
import os
import shutil
import random
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
#from tqdm import tqdm
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Function

import binarybrain as bb

In [None]:
# Run the PyTorch side on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# Show the installed BinaryBrain version string as the cell output.
bb.get_version_string()

In [None]:
# Back ends to run side by side; the wrapper layers below check these flags.
use_bb = True
use_torch = True

# NOTE(review): bin_mode is only referenced from a commented-out line in the visible code.
bin_mode = True
# When True (and use_bb), MipmapNetwork applies RealToBinary / BinaryToReal around the network.
bin_modulation = False
# Passed to RealToBinary (frame_modulation_size) and BinaryToReal (frame_integration_size).
frame_modulation_size = 1

# Values above 2 make the layers print input/output/gradient summaries.
verbose = 0

# Number of passes over loader_train in the training cell.
epochs = 8

# Arguments for MipmapNetwork.
# NOTE(review): `ch` is not passed at the construction site visible in this notebook.
loops = 5
depth = 3
ch = 32

# Number of MNIST tiles placed per composite image, and the size of each tile
rows=2
cols=2
img_h = 32
img_w = 32

# Directory for the cached dataset pickle and the saved networks.
data_path = './data/cmp_torch_mnist_iir_segmentation_v2'
mini_batch_size = 16

## 学習データ作成

In [None]:
# dataset
# torchvision's Resize expects the target size as (height, width); the original
# passed (img_w, img_h), which only worked because both are currently equal.
transform = transforms.Compose([
                transforms.Resize((img_h, img_w)),
                transforms.ToTensor(),
            ])

# MNIST is downloaded into dataset_path on first use.
dataset_path = './data/'
dataset_train = torchvision.datasets.MNIST(root=dataset_path, train=True, transform=transform, download=True)
dataset_test  = torchvision.datasets.MNIST(root=dataset_path, train=False, transform=transform, download=True)

def make_teacher_image(gen, rows, cols, margin=0):
    """Tile rows x cols samples into one source image and one 11-channel teacher image.

    Args:
        gen: iterator yielding (img, label) pairs; one pair is consumed per tile.
        rows, cols: number of tiles vertically / horizontally.
        margin: when positive, that many pixels are cropped from every border
            of the teacher image only (the source image keeps its full size).

    Returns:
        (source_img, teaching_img): float32 arrays. source_img has 1 channel;
        teaching_img has 11 channels (index = label for the digit pixels,
        index 10 for the background), thresholded at 0.5 and scaled per channel
        by the module-level `weight`.
    """
    canvas_h, canvas_w = rows*img_h, cols*img_w
    source_img   = np.zeros((1, canvas_h, canvas_w), dtype=np.float32)
    teaching_img = np.zeros((11, canvas_h, canvas_w), dtype=np.float32)
    for row in range(rows):
        top = row*img_h
        for col in range(cols):
            left = col*img_w
            img, label = next(gen)
            tile = (slice(top, top+img_h), slice(left, left+img_w))
            source_img[(0,) + tile] = img
            teaching_img[(label,) + tile] = img
            teaching_img[(10,) + tile] = (1.0-img)
    teaching_img = (teaching_img > 0.5).astype(np.float32)

    # Scale every class channel by its area-derived weight.
    for channel in range(11):
        teaching_img[channel] *= weight[channel]

    # Invert the source image at random (the teacher image is left as is).
    if random.random() > 0.5:
        source_img = 1.0 - source_img

    if margin > 0:
        teaching_img = teaching_img[:,margin:-margin,margin:-margin]
    return source_img, teaching_img

def transform_data(dataset, n, rows, cols, margin):
    """Build n tiled (source, teacher) image pairs from `dataset`.

    The dataset is read sequentially and wraps around to index 0 when it is
    exhausted. Returns two parallel lists: source images and teacher images.
    """
    def cycle_samples():
        size = len(dataset)
        pos = 0
        while True:
            yield dataset[pos % size]
            pos += 1

    samples = cycle_samples()
    pairs = [make_teacher_image(samples, rows, cols, margin) for _ in range(n)]
    source_imgs = [src for src, _ in pairs]
    teaching_imgs = [tch for _, tch in pairs]
    return source_imgs, teaching_imgs

class MyDatasets(torch.utils.data.Dataset):
    """Dataset over two parallel sequences of source and teacher images.

    `transforms`, when given, is called as transforms(source, teacher) and must
    return the transformed pair.
    """
    def __init__(self, source_imgs, teaching_imgs, transforms=None):
        self.source_imgs = source_imgs
        self.teaching_imgs = teaching_imgs
        self.transforms = transforms

    def __len__(self):
        return len(self.source_imgs)

    def __getitem__(self, index):
        pair = (self.source_imgs[index], self.teaching_imgs[index])
        if not self.transforms:
            return pair
        source_img, teaching_img = self.transforms(*pair)
        return source_img, teaching_img

# Build the tiled images used for filtering, or load them from the pickle cache.
# NOTE(review): the cache is reused whenever the file exists, regardless of the
# current rows / cols / img_h / img_w settings - delete it after changing them.
dataset_fname = os.path.join(data_path, 'dataset.pickle')
if os.path.exists(dataset_fname):
    # The five objects are read back in exactly the order they were dumped below.
    # NOTE(review): pickle.load executes arbitrary code; only load caches you created yourself.
    with open(dataset_fname, 'rb') as f:
        source_imgs_train = pickle.load(f)
        teaching_imgs_train = pickle.load(f)
        source_imgs_test = pickle.load(f)
        teaching_imgs_test = pickle.load(f)
        weight = pickle.load(f)
else:
    # Derive the class weights from the mean area covered by each class.
    areas = np.zeros((11))
    for img, label in dataset_train:
        img = img.numpy()
        areas[label] += np.mean(img)
        areas[10] += np.mean(1.0-img)
    areas /= len(dataset_train)

    # Inverse area, normalised so that the weights sum to 1.
    weight = 1 / areas
    weight /= np.sum(weight)

    # make_teacher_image reads the module-level `weight`, so it must exist before these calls.
    source_imgs_train, teaching_imgs_train = transform_data(dataset_train, 4096, rows, cols, 0) #29)
    source_imgs_test, teaching_imgs_test = transform_data(dataset_test, 128, rows, cols, 0) #, 29)

    os.makedirs(data_path, exist_ok=True)
    with open(dataset_fname, 'wb') as f:
        pickle.dump(source_imgs_train, f)
        pickle.dump(teaching_imgs_train, f)
        pickle.dump(source_imgs_test, f)
        pickle.dump(teaching_imgs_test, f)
        pickle.dump(weight, f)

my_dataset_train = MyDatasets(source_imgs_train, teaching_imgs_train)
my_dataset_test = MyDatasets(source_imgs_test, teaching_imgs_test)

# Only the training loader is shuffled.
loader_train = torch.utils.data.DataLoader(dataset=my_dataset_train, batch_size=mini_batch_size, shuffle=True)
loader_test = torch.utils.data.DataLoader(dataset=my_dataset_test, batch_size=mini_batch_size, shuffle=False)

In [None]:
def plot_data(x, y, n=2):
    """Plot source images next to their 11 per-class maps, one row per sample.

    Args:
        x: batch of source images, indexed as x[i][0] (needs a `.shape`).
        y: batch of 11-channel maps, indexed as y[i][j]; channels 0-9 are the
           digit classes and channel 10 is the background.
        n: maximum number of samples (rows) to draw.
    """
    n = min(x.shape[0], n)
    if n <= 0:
        # Nothing to draw; avoids creating a zero-height figure.
        return
    plt.figure(figsize=(18,2*n))
    for i in range(n):
        # Column 1 of each row holds the source image, columns 2-12 the class maps.
        plt.subplot(n,12,i*12+1)
        plt.title('source')
        plt.imshow(x[i][0], 'gray')
        for j in range(11):
            plt.subplot(n,12,i*12+2+j)
            plt.title('class=%d'%j if j < 10 else 'background')
            plt.imshow(y[i][j], 'gray')
    plt.tight_layout()
    plt.show()

In [None]:
# Visual check of the generated data (disabled; flip the condition to run it).
if False:
    plt.figure(figsize=(16,8))
    x, t = next(iter(loader_test))
    plot_data(x, t, 4)

In [None]:
def to_numpy(x):
    """Convert a NumPy array, torch tensor or BinaryBrain buffer to a NumPy array.

    - np.ndarray (including subclasses) is returned unchanged, not copied.
    - torch.Tensor is moved to the CPU, detached, cloned and converted.
    - bb.FrameBuffer / bb.Tensor are converted with their own .numpy().
    - Anything else is assumed to offer the torch.Tensor interface.

    isinstance is used instead of exact type equality so that ndarray
    subclasses no longer fall through to the tensor branch and fail on `.to`.
    """
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, torch.Tensor):
        return x.to('cpu').detach().clone().numpy()
    if isinstance(x, (bb.FrameBuffer, bb.Tensor)):
        return x.numpy()
    # Duck-typed fallback, kept from the original implementation.
    return x.to('cpu').detach().clone().numpy()

def calc_corrcoef(a, b):
    """Return the Pearson correlation coefficient between the flattened a and b."""
    flat_a = to_numpy(a).reshape(-1)
    flat_b = to_numpy(b).reshape(-1)
    corr_matrix = np.corrcoef(flat_a, flat_b)
    # Off-diagonal entry of the 2x2 correlation matrix.
    return corr_matrix[0][1]

def print_summary(x, text=""):
    """Print mean / std / min / max and a NaN flag for x, prefixed with [text].

    x may be None, in which case only "[text] None" is printed.
    """
    if x is None:
        print("[%s] None"%(text))
        return
    values = to_numpy(x)
    stats = (text, np.mean(values), np.std(values), np.min(values), np.max(values), np.isnan(values).any())
    print("[%s] mean:%1.8f std:%1.8f min:%1.8f max:%1.8f isnan:%d"%stats)

def print_diff(a, b, text=""):
    """Print summary statistics of a - b, with b reshaped to the shape of a."""
    lhs = to_numpy(a)
    rhs = to_numpy(b).reshape(lhs.shape)
    print_summary(lhs - rhs, text=text)

def print_diff_summary(a_bb, a_torch, text=""):
    """Compare a BinaryBrain value with its PyTorch counterpart.

    Prints the correlation coefficient, then summaries of both values and of
    their difference. a_torch is reshaped to the shape of a_bb first.
    """
    a_bb    = to_numpy(a_bb)
    a_torch = to_numpy(a_torch).reshape(a_bb.shape)
    print('[corrcoef:%f]'%calc_corrcoef(a_bb, a_torch))
    for values, suffix in ((a_bb, " bb   "), (a_torch, " torch"), (a_bb - a_torch, " diff")):
        print_summary(values, text=text+suffix)

def print_param_status(param, name=""):
    """Print NaN-ignoring mean / std / min / max of param and whether it contains NaN."""
    header = '[%s] mean:%f std:%f min:%f max:%f nan:'%(
        name, np.nanmean(param), np.nanstd(param), np.nanmin(param), np.nanmax(param))
    has_nan = np.isnan(param).any()
    print(header, has_nan)

def print_param_affine_bb(affine_bb, name=""):
    """Print the status of W, b, dW and db of a BinaryBrain affine layer."""
    for attr in ('W', 'b', 'dW', 'db'):
        tensor = getattr(affine_bb, attr)()
        print_param_status(tensor.numpy(), name+'.'+attr)

def print_shape(x, text=""):
    """Print x.get_shape() prefixed with [text], or "[text] None" when x is None."""
    if x is not None:
        print("[%s]"%text, x.get_shape())
        return
    print("[%s] None"%text)

In [None]:
def copy_affine_param(model_bb, model_torch):
    """Copy weight and bias of a PyTorch layer into a BinaryBrain affine layer.

    The PyTorch values are reshaped to the shapes reported by the BinaryBrain
    tensors before being written.
    """
    w_dst = model_bb.W()
    w_dst.set_numpy(to_numpy(model_torch.weight).reshape(w_dst.get_shape()))
    b_dst = model_bb.b()
    b_dst.set_numpy(to_numpy(model_torch.bias).reshape(b_dst.get_shape()))

def plot_affine_param_hist(model_bb, model_torch):
    """Print summaries and draw histograms of W and b for both back ends.

    The 2x2 figure shows W_bb, W_torch, b_bb and b_torch (left to right, top
    to bottom).
    """
    panels = [
        ("W_bb",    "W_bb   ", to_numpy(model_bb.W()).reshape(-1)),
        ("W_torch", "W_torch", to_numpy(model_torch.weight).reshape(-1)),
        ("b_bb",    "b_bb   ", to_numpy(model_bb.b()).reshape(-1)),
        ("b_torch", "b_torch", to_numpy(model_torch.bias).reshape(-1)),
    ]
    for _, label, values in panels:
        print_summary(values, label)
    # Subplot codes 221..224 address the four cells of the 2x2 grid.
    for position, (title, _, values) in enumerate(panels, start=221):
        plt.subplot(position)
        plt.title(title)
        plt.hist(values)
    plt.show()

In [None]:
def print_numpy_info(x, name="W_bb"):
    """Print std / min / max, a NaN flag and the shape of the array x."""
    header = '[%s] std:%f, min:%f, max:%f, nan:'%(name, np.std(x), np.min(x), np.max(x))
    has_nan = np.isnan(x).any()
    print(header, has_nan, x.shape)

In [None]:
def dump_object(obj, file):
    """Pickle obj into the file at path `file`, overwriting any existing file."""
    with open(file, 'wb') as stream:
        pickle.dump(obj, stream)

def load_object(file):
    """Return the first pickled object stored in the file at path `file`.

    NOTE(review): unpickling can execute arbitrary code; only use this on files
    written by dump_object from a trusted source.
    """
    with open(file, 'rb') as stream:
        loaded = pickle.load(stream)
    return loaded

## ネットワーク定義

In [None]:
from torch.autograd import Function
class Binarize(Function):
    """Binarize to +1 / -1 in the forward pass; gate the gradient in the backward pass.

    forward:  y = +1 where x > 0, otherwise -1 (so NaN inputs map to -1 instead
              of being left as uninitialised memory, which the previous
              `x.new(x.size())` buffer allowed).
    backward: the incoming gradient is passed through unchanged where
              -1 < x < 1 and zeroed where x >= 1 or x <= -1.
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        one = torch.ones_like(x)
        return torch.where(x > 0, one, -one)

    @staticmethod
    def backward(ctx, dy):
        x, = ctx.saved_tensors
        # masked_fill returns a new tensor, so dy itself is not modified.
        return dy.masked_fill((x >= 1) | (x <= -1), 0)
binarize = Binarize.apply

class Through(Function):
    """Identity function: copies the input forward and the gradient backward.

    Serves as a hook point where debug output can be inserted on the PyTorch
    side of the graph.
    """
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, dy):
        return dy.clone()

through = Through.apply

In [None]:
# In binary mode the BIT dtype could be used instead to reduce memory usage
#bin_dtype = bb.DType.BIT if bin_mode else bb.DType.FP32
bin_dtype = bb.DType.FP32

class Buffer():
    """Container holding the same activation for PyTorch and BinaryBrain.

    Attributes:
        torch: torch tensor, or None when not populated.
        bb:    bb.FrameBuffer, or None when not populated.
    Which of the two is filled depends on the module-level use_torch / use_bb flags.
    """
    def __init__(self, shape=None):
        self.torch = None
        self.bb    = None
        if shape is None:
            return
        # shape[0] is the frame (batch) size, shape[1:] the per-frame node shape.
        if use_torch:
            self.torch = torch.zeros(shape).to(device)
        if use_bb:
            self.bb = bb.FrameBuffer(shape[0], shape[1:])

    def get_shape(self):
        """Return the shape as a list (frames first), or None when the buffer is empty."""
        has_torch = use_torch and self.torch is not None
        if has_torch:
            return [*self.torch.shape]
        has_bb = use_bb and self.bb is not None
        if has_bb:
            return [self.bb.get_frame_size()] + self.bb.get_node_shape()
        return None

    @staticmethod
    def zeros(shape):
        """Create a zero-filled Buffer of the given shape for every enabled back end."""
        buf = Buffer()
        if use_torch:
            buf.torch = torch.zeros(shape).to(device)
        if use_bb:
            buf.bb = bb.FrameBuffer.zeros(shape[0], shape[1:])
        return buf

    @staticmethod
    def zeros_like(x):
        """Create a zero-filled Buffer with the same shape as the Buffer x."""
        shape = x.get_shape()
        return Buffer.zeros(shape)

In [None]:
class AffineBlock(bb.Sequential):
    """BinaryBrain dense affine layer with optional batch norm and binarization.

    Args:
        out_ch: number of output channels; the affine output shape is [out_ch, 1, 1].
        batch_norm: append a BatchNormalization layer after the affine layer.
        activation: append a Binarize layer at the end.
        name: prefix for the sub-layer names.
    """
    def __init__(self, out_ch, batch_norm, activation, name=""):
        self.batch_norm = batch_norm
        self.activation = activation
        self.affine = bb.DenseAffine([out_ch, 1, 1], name=name+'_affine', initializer="")
        self.bn     = bb.BatchNormalization(name=name+'_bn')
        self.act    = bb.Binarize(name=name+'_act', bin_dtype=bin_dtype)
        # Only the enabled layers are registered with the Sequential base class.
        layers = [self.affine]
        if batch_norm: layers.append(self.bn)
        if activation: layers.append(self.act)
        super(AffineBlock, self).__init__(layers, name=name)

    def forward(self, x, train=True):
        """Run affine -> (batch norm) -> (binarize).

        The `train` flag is forwarded to every sub-layer; previously it was
        ignored and the sub-layers were always called with train=True.
        """
        x = self.affine.forward(x, train=train)
        if self.batch_norm:
            x = self.bn.forward(x, train=train)
        if self.activation:
            x = self.act.forward(x, train=train)
        return x

    def backward(self, dy):
        """Back-propagate dy through the enabled layers in reverse order."""
        if self.activation:
            dy = self.act.backward(dy)
        if self.batch_norm:
            dy = self.bn.backward(dy)
        dy = self.affine.backward(dy)
        if verbose > 2:
            print_summary(dy)
        return dy


class Convolution(bb.Sequential):
    """3x3 convolution with optional batch norm and binarization on both back ends.

    The PyTorch path is Conv2d -> BatchNorm2d -> binarize -> through; the
    BinaryBrain path is bb.Convolution2d wrapping an AffineBlock. Which paths
    exist depends on the module-level use_torch / use_bb flags.
    """
    def __init__(self, in_ch, out_ch, batch_norm=True, activation=True, name=""):
        self.batch_norm = batch_norm
        self.activation = activation

        if use_torch:
            self.cnv_torch = nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='reflect').to(device)
            self.bn_torch  = nn.BatchNorm2d(out_ch, eps=1e-07).to(device)

        if use_bb:
            self.blk_bb = AffineBlock(out_ch, name=name+'_blk', batch_norm=batch_norm, activation=activation)
            self.cnv_bb = bb.Convolution2d(
                                    self.blk_bb,
                                    filter_size=(3, 3),
                                    padding='same',
                                    name=name+'_cnv',
                                    fw_dtype=bin_dtype)
            super(Convolution, self).__init__([self.cnv_bb], name=name)
        else:
            super(Convolution, self).__init__([], name=name)

    def parameters(self):
        """Return the trainable PyTorch parameters (empty when use_torch is off)."""
        if not use_torch:
            # No torch modules were created in __init__.
            return []
        params = list(self.cnv_torch.parameters())
        if self.batch_norm:
            params += list(self.bn_torch.parameters())
        return params

    def set_input_shape(self, shape):
        """Set the BinaryBrain input shape and return the resulting output shape."""
        shape = self.cnv_bb.set_input_shape(shape)
        return shape

    def forward(self, x, train=True):
        """Run the layer on every enabled back end and return a new Buffer."""
        if verbose > 2:
            print("[%s] forward"%self.name)
            print_summary(x.torch, text='in torch')
            print_summary(x.bb,    text='in bb   ')

        y = Buffer()
        if use_torch:
            y.torch = self.cnv_torch(x.torch)
            if self.batch_norm:
                # Honour `train`: with train=False use the running statistics
                # instead of batch statistics (and do not update them).
                self.bn_torch.train(train)
                y.torch = self.bn_torch(y.torch)
            if self.activation:
                y.torch = binarize(y.torch)
            y.torch = through(y.torch)

        if use_bb:
            y.bb = self.cnv_bb.forward(x.bb, train=train)

        if verbose > 2:
            print_summary(y.torch, text='out torch')
            print_summary(y.bb,    text='out bb   ')
        return y

    def backward(self, dy):
        """Back-propagate through the BinaryBrain path (PyTorch uses autograd)."""
        if use_bb:
            dy = self.cnv_bb.backward(dy)
        return dy

class UpSampling(bb.Sequential):
    """2x nearest-neighbour upsampling on every enabled back end."""
    def __init__(self, name=""):
        if use_torch:
            self.up_torch = nn.Upsample(scale_factor=2, mode='nearest').to(device)

        bb_layers = []
        if use_bb:
            self.up_bb = bb.UpSampling((2, 2), fw_dtype=bin_dtype)
            bb_layers.append(self.up_bb)
        super(UpSampling, self).__init__(bb_layers, name=name)

    def parameters(self):
        """Upsampling has no trainable PyTorch parameters."""
        return []

    def forward(self, x, train=True):
        """Upsample the Buffer x and return the result as a new Buffer."""
        out = Buffer()
        if use_torch:
            out.torch = self.up_torch(x.torch)
        if use_bb:
            out.bb = self.up_bb.forward(x.bb, train=train)
        return out

class MaxPooling(bb.Sequential):
    """2x2 max pooling on every enabled back end.

    Both the PyTorch and the BinaryBrain layer objects are always created;
    only forward() checks the use_torch / use_bb flags.
    """
    def __init__(self, name=""):
        self.pool_torch = nn.MaxPool2d(kernel_size=2, stride=2).to(device)
        self.pool_bb    = bb.MaxPooling((2, 2), fw_dtype=bin_dtype)
        super(MaxPooling, self).__init__([self.pool_bb], name=name)

    def parameters(self):
        """Pooling has no trainable PyTorch parameters."""
        return []

    def forward(self, x, train=True):
        """Pool the Buffer x and return the result as a new Buffer."""
        out = Buffer()
        if use_torch:
            out.torch = self.pool_torch(x.torch)
        if use_bb:
            out.bb = self.pool_bb.forward(x.bb, train=train)
        return out

class Concatenate(bb.Sequential):
    """Concatenate two Buffers (PyTorch: along dim 1) on every enabled back end."""
    def __init__(self, name=""):
        self.cat_bb = bb.Concatenate()
        super(Concatenate, self).__init__([self.cat_bb], name=name)

    def parameters(self):
        """Concatenation has no trainable PyTorch parameters."""
        return []

    def set_input_shape(self, shape):
        """Forward the list of input shapes to the BinaryBrain layer; returns its output shape."""
        return self.cat_bb.set_input_shape(shape)

    def forward(self, x, train=True):
        """Concatenate x[0] and x[1] and return the result as a new Buffer."""
        first, second = x[0], x[1]
        out = Buffer()
        if use_torch:
            out.torch = torch.cat((first.torch, second.torch), dim=1)
        if use_bb:
            out.bb = self.cat_bb.forward([first.bb, second.bb], train=train)
        return out

    def backward(self, dy):
        """Split the BinaryBrain gradient back into the gradients of the inputs."""
        return self.cat_bb.backward(dy)


class ConvBlock(bb.Sequential):
    """Basic block: two Convolution layers in sequence (in_ch -> hid_ch -> out_ch)."""
    def __init__(self, in_ch=32, hid_ch=32, out_ch=32, name=""):
        self.cnv0 = Convolution(in_ch, hid_ch, name=name+"_cnv0")
        self.cnv1 = Convolution(hid_ch, out_ch, name=name+"_cnv1")
        super(ConvBlock, self).__init__([self.cnv0, self.cnv1], name=name)

    def parameters(self):
        """Return the PyTorch parameters of both convolutions."""
        return self.cnv0.parameters() + self.cnv1.parameters()

    def set_input_shape(self, shape):
        """Propagate the input shape through both layers and return the output shape."""
        for layer in (self.cnv0, self.cnv1):
            shape = layer.set_input_shape(shape)
        return shape

    def forward(self, x, train=True):
        """Apply cnv0 then cnv1."""
        for layer in (self.cnv0, self.cnv1):
            x = layer.forward(x, train=train)
        return x

    def backward(self, dy):
        """Back-propagate through cnv1 then cnv0."""
        for layer in (self.cnv1, self.cnv0):
            dy = layer.backward(dy)
        return dy

### サブブロック

In [None]:
class MainBlock(bb.Sequential):
    """Top-level (1/1 scale) block of the hierarchy.

    forward concatenates the input x0 with the upsampled feedback x1, runs
    cnv0 to get the hidden features, and produces two outputs from them:
    y = cnv1(hidden), with 11 channels, and v = pool(hidden) for the next scale.
    """
    def __init__(self, ch=32, top=False):
        self.ch   = ch
        self.up   = UpSampling()
        self.cat  = Concatenate()
        self.cnv0 = ConvBlock(1+self.ch, self.ch, self.ch, name="m_cnv0")
        self.cnv1 = ConvBlock(self.ch,   self.ch, 11, name="m_cnv1")
        self.pool = MaxPooling()
        super(MainBlock, self).__init__([self.up, self.cnv0, self.cnv1, self.pool])

    def parameters(self):
        """Return the PyTorch parameters of both convolution blocks."""
        return self.cnv0.parameters() + self.cnv1.parameters()

    def set_input_shape(self, x0_shape, x1_shape):
        """Configure the BinaryBrain shapes; returns (y_shape, v_shape)."""
        up_shape     = self.up.set_input_shape(x1_shape)
        cat_shape    = self.cat.set_input_shape([x0_shape, up_shape])
        hidden_shape = self.cnv0.set_input_shape(cat_shape)
        y_shape = self.cnv1.set_input_shape(hidden_shape)
        v_shape = self.pool.set_input_shape(hidden_shape)
        return y_shape, v_shape

    def forward(self, x0, x1, first, last, train=True):
        """Return (y, v); v is None when `last` is set.

        When `first` is set there is no feedback yet, so x1 is replaced by a
        zero Buffer shaped like x0 but with self.ch channels.
        """
        if first:
            zero_shape = x0.get_shape()
            zero_shape[1] = self.ch
            x1 = Buffer.zeros(zero_shape)
        else:
            x1 = self.up.forward(x1, train=train)

        merged = self.cat.forward([x0, x1], train=train)
        hidden = self.cnv0.forward(merged, train=train)
        y = self.cnv1.forward(hidden, train=train)
        v = None if last else self.pool.forward(hidden, train=train)
        return y, v

    def backward(self, dy, dv, first, last):
        """Back-propagate (dy, dv) and return (dx0, dx1); dx1 is None when `first`."""
        grad = self.cnv1.backward(dy)
        if not last:
            # The hidden features fed both cnv1 and the pooling branch.
            grad = grad + self.pool.backward(dv)
        dx = self.cnv0.backward(grad)
        dx0, dx1 = self.cat.backward([dx])
        dx1 = None if first else self.up.backward(dx1)
        return dx0, dx1


class ScaleBlock(bb.Sequential):
    """Block for the lower (reduced-scale) levels of the hierarchy.

    forward concatenates the previous state x0, the upsampled coarser state x1
    and the input u coming from the finer level, applies one ConvBlock to get
    y, and pools y to get the input v of the next coarser level.
    """
    def __init__(self, ch=32, top=False):
        self.ch   = ch
        self.up   = UpSampling()
        self.cat0 = Concatenate()
        self.cat1 = Concatenate()
        self.cnv  = ConvBlock(3*self.ch, self.ch, self.ch, name="s_cnv")
        self.pool = MaxPooling()
        super(ScaleBlock, self).__init__([self.up, self.cnv, self.pool])

    def parameters(self):
        """Return the PyTorch parameters of the convolution block."""
        return self.cnv.parameters()

    def set_input_shape(self, x0_shape, x1_shape, u_shape):
        """Configure the BinaryBrain shapes; returns (y_shape, v_shape)."""
        up_shape   = self.up.set_input_shape(x1_shape)
        pair_shape = self.cat0.set_input_shape([x0_shape, up_shape])
        full_shape = self.cat1.set_input_shape([pair_shape, u_shape])
        y_shape = self.cnv.set_input_shape(full_shape)
        v_shape = self.pool.set_input_shape(y_shape)
        return y_shape, v_shape

    def forward(self, x0, x1, u, bottom, first, last, train=True):
        """Return (y, v); v is None when `last` is set.

        Zero Buffers shaped like u stand in for x0 when `first` is set and for
        x1 when `first` or `bottom` is set.
        """
        if first:
            x0 = Buffer.zeros_like(u)
        if first or bottom:
            x1 = Buffer.zeros_like(u)
        else:
            x1 = self.up.forward(x1, train=train)

        state = self.cat0.forward([x0, x1], train=train)
        merged = self.cat1.forward([state, u], train=train)
        y = self.cnv.forward(merged, train=train)
        v = None if last else self.pool.forward(y, train=train)
        return y, v

    def backward(self, dy, dv, bottom, first, last):
        """Back-propagate (dy, dv) and return (dx0, dx1, du).

        dx0 is None when `first`; dx1 is None when `first` or `bottom`.
        """
        grad = dy
        if not last:
            # y fed both the caller and the pooling branch.
            grad = dy + self.pool.backward(dv)
        dx = self.cnv.backward(grad)
        dx, du = self.cat1.backward([dx])
        dx0, dx1 = self.cat0.backward([dx])
        if first:
            dx0 = None
        if first or bottom:
            dx1 = None
        else:
            dx1 = self.up.backward(dx1)
        return dx0, dx1, du

### Mipmap型ネット

In [None]:
class MipmapNetwork(bb.Sequential):
    """Mipmap-style network.

    One MainBlock (m_blk) and one ScaleBlock (s_blk) are applied repeatedly:
    forward() runs m_blk `loops` times on the same input and, from the second
    iteration on, runs s_blk over up to `depth` levels. The s_blk outputs are
    kept in `mipmap` and fed back on the following iteration. Every iteration
    contributes one output to the returned list.
    """
    def __init__(self, loops=2, depth=4, ch=32):
        self.loops = loops
        # The number of scale levels is capped at loops-2.
        self.depth = min(depth, loops-2)
        self.ch    = ch
        # Only used when use_bb and bin_modulation are both set (see forward/backward).
        self.r2b   = bb.RealToBinary(frame_modulation_size=frame_modulation_size, bin_dtype=bin_dtype)
        self.b2r   = bb.BinaryToReal(frame_integration_size=frame_modulation_size, bin_dtype=bin_dtype)
        # NOTE(review): self.up is not referenced by any method of this class.
        self.up    = UpSampling()
        self.m_blk = MainBlock(ch)
        self.s_blk = ScaleBlock(ch)
        super(MipmapNetwork, self).__init__([self.m_blk, self.s_blk])
    
    def parameters(self):
        """Return the PyTorch parameters of both blocks."""
        return self.m_blk.parameters() + self.s_blk.parameters()
    
    def set_input_shape(self, shape):
        """Configure the BinaryBrain shapes of all sub-blocks; returns the output shape."""
        self.shape = copy.copy(shape)
        shape = self.r2b.set_input_shape(shape)
        
        # m_blk: x0 is the input itself, x1 a ch-channel map at half resolution.
        x0_shape = copy.copy(shape)
        x1_shape = copy.copy(shape)
        x1_shape[0] = self.ch
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        y_shape, v_shape = self.m_blk.set_input_shape(x0_shape, x1_shape)
        
        # s_blk: x0 matches the pooled m_blk output, x1 is again half that resolution.
        x0_shape = copy.copy(v_shape)
        x0_shape[0] = self.ch
        x1_shape = copy.copy(v_shape)
        x1_shape[0] = self.ch
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        self.s_blk.set_input_shape(x0_shape, x1_shape, v_shape)
        
        y_shape = self.b2r.set_input_shape(y_shape)
        return y_shape
    
    def forward(self, x, train=True):
        """Run `loops` iterations and return the list of per-iteration outputs."""
        if use_bb and bin_modulation:
            x.bb = self.r2b.forward(x.bb)
        
        # mipmap[j] holds the s_blk output of level j from the previous iteration.
        mipmap = [None]*(self.depth+1)
        y_list = []
        for i in range(self.loops):
            # Fewer levels are visited as the last iteration approaches.
            depth = min(self.depth, self.loops-1 - i)
            y, v = self.m_blk.forward(x, mipmap[0], first=(i<=1), last=(i==self.loops-1), train=train)
            if i >= 1:
                for j in range(depth):
                    bottom = (j==self.depth-1)
                    first  = (i==1)
                    last   = (j==(depth-1))
                    # v is replaced by this level's pooled output and feeds level j+1.
                    mipmap[j], v = self.s_blk.forward(mipmap[j], mipmap[j+1], v, bottom, first, last, train=train)
            if use_bb and bin_modulation:
                y.bb = self.b2r.forward(y.bb)
            y_list.append(y)
        return y_list
    
    def backward(self, dy_list):
        """Back-propagate the per-iteration gradients; returns the last dx0 from m_blk.

        Iterations and levels are visited in the reverse order of forward();
        here `mipmap` holds gradients for the s_blk outputs instead of activations.
        """
        du = None
        mipmap = [None]*(self.depth+1)
        for i in reversed(range(self.loops)):
            depth = min(self.depth, self.loops-1 - i)
            if i >= 1:
                for j in reversed(range(depth)):
                    bottom = (j==self.depth-1)
                    first  = (i==1)
                    last   = (j==(depth-1))
                    dx0, dx1, du = self.s_blk.backward(mipmap[j], du, bottom, first, last)
                    # At the deepest visited level the slots are overwritten,
                    # elsewhere the gradients are accumulated.
                    if last:
                        mipmap[j] = dx0
                        if not bottom:
                            mipmap[j+1] = dx1
                    elif not first:
                        mipmap[j] += dx0
                        mipmap[j+1] += dx1
            dy = dy_list[i]
            if use_bb and bin_modulation:
                dy = self.b2r.backward(dy_list[i])
            dx0, dx1 = self.m_blk.backward(dy, du, first=(i<=1), last=(i==self.loops-1))
            mipmap[0] = dx1
        return dx0

def view(net, loader, n=2):
    """Run the first batch of `loader` through `net` and plot the final output.

    The network layers expect a Buffer (with .bb / .torch members) and return a
    list of Buffers, so the batch is wrapped accordingly; the previous version
    passed a raw bb.FrameBuffer and called .numpy() on a Buffer.

    Args:
        net: network whose forward(x, train=False) returns a list of Buffers.
        loader: iterable yielding (x, t) batches; only the first batch is used.
        n: maximum number of samples to plot.
    """
    x_torch, _ = next(iter(loader))
    x_np = np.array(x_torch).astype(np.float32)

    x = Buffer()
    if use_bb:
        x.bb = bb.FrameBuffer.from_numpy(x_np)
    if use_torch:
        x.torch = x_torch.to(device)

    # The last list entry is the output of the final iteration.
    y = net.forward(x, train=False)[-1]
    if use_bb:
        y_np = y.bb.numpy()
    else:
        y_np = y.torch.to('cpu').detach().numpy()

    plot_data(x_np, y_np, n)

## 学習

In [None]:
def print_param():
    """Print BinaryBrain-vs-PyTorch weight and bias differences for every convolution of `net`.

    Uses the attribute names actually defined by MipmapNetwork (m_blk / s_blk)
    and ScaleBlock (a single ConvBlock named cnv); the previous m_net / s_net
    paths do not exist on the network.
    """
    layers = [
        ('m_blk.cnv0.cnv0', net.m_blk.cnv0.cnv0),
        ('m_blk.cnv0.cnv1', net.m_blk.cnv0.cnv1),
        ('m_blk.cnv1.cnv0', net.m_blk.cnv1.cnv0),
        ('m_blk.cnv1.cnv1', net.m_blk.cnv1.cnv1),
        ('s_blk.cnv.cnv0',  net.s_blk.cnv.cnv0),
        ('s_blk.cnv.cnv1',  net.s_blk.cnv.cnv1),
    ]
    for label, layer in layers:
        print_diff_summary(layer.blk_bb.affine.W(), layer.cnv_torch.weight, text=label+'.W')
        print_diff_summary(layer.blk_bb.affine.b(), layer.cnv_torch.bias,   text=label+'.b')

In [None]:
# Build the network from the configuration cell. `ch` is passed explicitly so
# that changing it in the configuration actually takes effect.
net = MipmapNetwork(loops=loops, depth=depth, ch=ch)
if use_bb:
    # Input is one channel of rows x cols tiles.
    net.set_input_shape([1, img_h*rows, img_w*cols])
    net.send_command("binary true")

In [None]:
# Histogram of the initial weights of the first convolution on both back ends
# (disabled). The paths follow the current network structure: m_blk, and
# blk_bb.affine for the BinaryBrain weights.
if False:
    first_conv = net.m_blk.cnv0.cnv0
    W_bb = first_conv.blk_bb.affine.W().numpy().reshape(-1)
    W_torch = first_conv.cnv_torch.weight.to('cpu').clone().detach().numpy().reshape(-1)

    plt.figure(figsize=(10,4))
    plt.subplot(121)
    plt.title('BinaryBrain')
    plt.hist(W_bb, bins=30)
    plt.subplot(122)
    plt.title('PyTorch')
    plt.hist(W_torch, bins=30)

In [None]:
def copy_param(model):
    """Copy all PyTorch parameters of a Convolution wrapper into its BinaryBrain layers.

    Covers the affine weight / bias and the batch-norm running mean, running
    variance, gamma and beta. Values are reshaped to the BinaryBrain shapes.
    """
    def assign(get_dst, src):
        # get_dst is the BinaryBrain accessor (e.g. affine.W); it is called
        # once for the write target and once for the target shape.
        get_dst().set_numpy(to_numpy(src).reshape(get_dst().get_shape()))

    affine_bb, bn_bb = model.blk_bb.affine, model.blk_bb.bn
    affine_torch, bn_torch = model.cnv_torch, model.bn_torch

    assign(affine_bb.W, affine_torch.weight)
    assign(affine_bb.b, affine_torch.bias)
    assign(bn_bb.running_mean, bn_torch.running_mean)
    assign(bn_bb.running_var, bn_torch.running_var)
    assign(bn_bb.gamma, bn_torch.weight)
    assign(bn_bb.beta, bn_torch.bias)

In [None]:
def compare_param(model, grad=False):
    """Print BinaryBrain-minus-PyTorch differences for one Convolution wrapper.

    Always compares affine W / b and the batch-norm running statistics, gamma
    and beta; with grad=True the corresponding gradients are compared as well.
    """
    affine_bb, bn_bb = model.blk_bb.affine, model.blk_bb.bn
    affine_torch, bn_torch = model.cnv_torch, model.bn_torch

    print_diff(affine_bb.W(), affine_torch.weight, text="W")
    print_diff(affine_bb.b(), affine_torch.bias, text="b")
    if grad:
        print_diff(affine_bb.dW(), affine_torch.weight.grad, text="dW")
        print_diff(affine_bb.db(), affine_torch.bias.grad, text="db")

    print_diff(bn_bb.running_mean(), bn_torch.running_mean, text="running_mean")
    print_diff(bn_bb.running_var(), bn_torch.running_var, text="running_var")
    print_diff(bn_bb.gamma(), bn_torch.weight, text="gamma")
    print_diff(bn_bb.beta(), bn_torch.bias, text="beta")
    if grad:
        print_diff(bn_bb.dgamma(), bn_torch.weight.grad, text="dgamma")
        print_diff(bn_bb.dbeta(), bn_torch.bias.grad, text="dbeta")

In [None]:
# Copy the PyTorch parameters into the BinaryBrain layers (disabled).
# The paths follow the current structure: m_blk has two ConvBlocks, s_blk one.
if False:
    for conv_block in (net.m_blk.cnv0, net.m_blk.cnv1, net.s_blk.cnv):
        copy_param(conv_block.cnv0)
        copy_param(conv_block.cnv1)

In [None]:
# Compare the parameters of both back ends before training (disabled).
# The paths follow the current structure: m_blk has two ConvBlocks, s_blk one.
if False:
    for conv_block in (net.m_blk.cnv0, net.m_blk.cnv1, net.s_blk.cnv):
        compare_param(conv_block.cnv0)
        compare_param(conv_block.cnv1)

### 学習実施

In [None]:
# Train both back ends on identical mini-batches and log the per-batch loss of each.
# (The BinaryBrain input shape is set in the cell that constructs `net`.)

if use_torch:
    weight_torch    = torch.from_numpy(weight.astype(np.float32)).clone().to(device)
    criterion_torch = nn.CrossEntropyLoss(weight=weight_torch)  # weight the classes according to their area
    optimizer_torch = optim.Adam(net.parameters(), lr=0.001)

if use_bb:
    criterion_bb = bb.LossSoftmaxCrossEntropy()
    metrics_bb   = bb.MetricsCategoricalAccuracy()
    optimizer_bb = bb.OptimizerAdam(learning_rate=0.001)
    optimizer_bb.set_variables(net.get_parameters(), net.get_gradients())

    criterion_bb.clear()
    metrics_bb.clear()
    net.clear()

# One entry per mini-batch; a back end that is disabled logs 0.
log_loss_bb    = []
log_loss_torch = []

for epoch in range(epochs):
    # learning
    if use_bb:
        criterion_bb.clear()
        metrics_bb.clear()
        net.clear()
    
    with tqdm(loader_train) as tqdm_loadr:
        for x_torch, t_torch in tqdm_loadr:
            if use_torch:
                optimizer_torch.zero_grad()
            net.clear()
            
            # Binarize the input in place: > 0.5 becomes +1, everything else -1.
            x_torch[x_torch >  0.5] = +1
            x_torch[x_torch <= 0.5] = -1
            
            x = Buffer()
            t = Buffer()
            if use_bb:
                x.bb = bb.FrameBuffer.from_numpy(np.array(x_torch).astype(np.float32)).astype(bin_dtype)
                t.bb = bb.FrameBuffer.from_numpy(np.array(t_torch).astype(np.float32))
            
            if use_torch:
                x.torch = x_torch.to(device)
                t.torch = t_torch.to(device)
            
            # One output Buffer per network iteration.
            y_list = net.forward(x, train=True)
            
            if use_torch:
                # Stack the outputs of all iterations along the batch axis and
                # repeat the target to match; the class index is the argmax
                # over the teacher channels.
                y_torch_list = []
                for y in y_list:
                    y_torch_list.append(y.torch)
                yy_torch = torch.cat(y_torch_list, 0)
                tt_torch = torch.cat([t.torch]*loops, 0)
                loss_torch = criterion_torch(yy_torch, torch.argmax(tt_torch, dim=1))
                loss_torch.backward()
            
            if use_bb:
                # One gradient per iteration output, scaled by the number of outputs.
                dy_list = []
                for y in y_list:
                    dy = criterion_bb.calculate(y.bb.astype(bb.DType.FP32), t.bb)
                    dy /= len(y_list)
                    dy_list.append(dy)
                net.backward(dy_list)
            
            if use_torch:
                optimizer_torch.step()
            if use_bb:
                optimizer_bb.update()
            
            loss_val_torch = 0
            loss_val_bb    = 0
            
            if use_torch:
                loss_val_torch = loss_torch.item()
            if use_bb:
                # Read the accumulated loss, then reset it for the next mini-batch.
                loss_val_bb    = criterion_bb.get()
                criterion_bb.clear()
            
            log_loss_torch.append(loss_val_torch)
            log_loss_bb.append(loss_val_bb)
            
            tqdm_loadr.set_postfix(loss_bb=loss_val_bb, loss_torch=loss_val_torch)
    
    # Save the BinaryBrain network after every epoch.
    bb.save_networks(data_path, net, backups=3)

In [None]:
-------------------

In [None]:
# Compare parameters and gradients of both back ends after training.
# Both back ends must be enabled, and the paths must follow the current
# structure (m_blk with two ConvBlocks, s_blk with the single ConvBlock `cnv`).
if use_bb and use_torch:
    layers = [
        ('m_blk.cnv0.cnv0', net.m_blk.cnv0.cnv0),
        ('m_blk.cnv0.cnv1', net.m_blk.cnv0.cnv1),
        ('m_blk.cnv1.cnv0', net.m_blk.cnv1.cnv0),
        ('m_blk.cnv1.cnv1', net.m_blk.cnv1.cnv1),
        ('s_blk.cnv.cnv0',  net.s_blk.cnv.cnv0),
        ('s_blk.cnv.cnv1',  net.s_blk.cnv.cnv1),
    ]
    for label, layer in layers:
        print(label)
        compare_param(layer, grad=True)

In [None]:
-------------

In [None]:
# Run one held-out batch through the network for the visual checks below.
# The test loader is used directly; the previous version iterated the progress
# bar left over from the training cell, i.e. the shuffled training data.
with torch.no_grad():
    for x_torch, t_torch in loader_test:
        # Un-binarized copies are kept for plotting.
        x_np = np.array(x_torch).astype(np.float32)
        t_np = np.array(t_torch).astype(np.float32)

        # Feed the network the same +1 / -1 input encoding it was trained on.
        x_bin = torch.where(x_torch > 0.5, torch.ones_like(x_torch), -torch.ones_like(x_torch))

        x = Buffer()
        t = Buffer()
        if use_bb:
            x.bb = bb.FrameBuffer.from_numpy(np.array(x_bin).astype(np.float32)).astype(bin_dtype)
            t.bb = bb.FrameBuffer.from_numpy(t_np)

        if use_torch:
            x.torch = x_bin.to(device)
            t.torch = t_torch.to(device)

        y_list = net.forward(x, train=False)
        # Keep the output of the final iteration of each back end.
        if use_torch:
            y_torch = y_list[-1].torch.to('cpu').detach().numpy()
        if use_bb:
            y_bb    = y_list[-1].bb.numpy()
        break

In [None]:
# Show up to 8 samples of the BinaryBrain output from the evaluation cell.
if use_bb:
    print('BinaryBrain')
    plot_data(x_np, y_bb, 8)

In [None]:
# y_bb only exists when the BinaryBrain back end is enabled.
if use_bb:
    plot_data(x_np, y_bb, 2)

In [None]:
# Show up to 8 samples of the PyTorch output from the evaluation cell.
# plot_data needs the NumPy batch x_np; `x` is a Buffer and has no .shape.
if use_torch:
    print('PyTorch')
    plot_data(x_np, y_torch, 8)

In [None]:
# Per-mini-batch training loss of every enabled back end.
for enabled, series, label in ((use_torch, log_loss_torch, 'PyTorch'),
                               (use_bb, log_loss_bb, 'BinaryBrain')):
    if enabled:
        plt.plot(series, label=label)
plt.ylabel('loss')
plt.legend()

In [None]:
-------

In [None]:
# Weight difference of the first convolution. The network attribute is m_blk,
# and the BinaryBrain weights live on blk_bb.affine.
print_diff_summary(net.m_blk.cnv0.cnv0.blk_bb.affine.W(), net.m_blk.cnv0.cnv0.cnv_torch.weight)

In [None]:
# Running mean of the first PyTorch batch norm (the network attribute is m_blk).
net.m_blk.cnv0.cnv0.bn_torch.running_mean

In [None]:
# NOTE(review): this repeats the copy_affine_param definition from the helper
# cells near the top of the notebook; the later definition shadows the earlier one.
def copy_affine_param(model_bb, model_torch):
    """Copy weight and bias of a PyTorch layer into a BinaryBrain affine layer."""
    w_dst = model_bb.W()
    w_dst.set_numpy(to_numpy(model_torch.weight).reshape(w_dst.get_shape()))
    b_dst = model_bb.b()
    b_dst.set_numpy(to_numpy(model_torch.bias).reshape(b_dst.get_shape()))

In [None]:
# Running mean of the first BinaryBrain batch norm. The network attribute is
# m_blk and the batch-norm layer is blk_bb.bn.
net.m_blk.cnv0.cnv0.blk_bb.bn.running_mean().numpy()

In [None]:
# Scale (gamma) of the first PyTorch batch norm (the network attribute is m_blk).
net.m_blk.cnv0.cnv0.bn_torch.weight

In [None]:
# Save the BinaryBrain network under data_path (same call as at the end of each training epoch).
bb.save_networks(data_path, net, backups=3)

In [None]:
# Training loss of both back ends. The series are labelled by back end: both
# Convolution paths are built with batch norm enabled, so the former
# 'batch_norm:ON' / 'batch_norm:OFF' labels did not describe the data.
if use_torch:
    plt.plot(log_loss_torch, label='PyTorch')
if use_bb:
    plt.plot(log_loss_bb, label='BinaryBrain')
plt.ylabel('loss')
plt.legend()

In [None]:
# Compare the per-iteration outputs of both back ends. y_list holds Buffer
# objects, which expose .bb / .torch members and cannot be subscripted.
print_diff_summary(y_list[0].bb, y_list[0].torch, text='y0')
print_diff_summary(y_list[1].bb, y_list[1].torch, text='y1')
#print_diff_summary(y_list[2].bb, y_list[2].torch, text='y2')
#print_diff_summary(y_list[3].bb, y_list[3].torch, text='y3')

In [None]:
# Weight-gradient difference of the first ScaleBlock convolution. The paths
# follow the current structure: s_blk.cnv, and blk_bb.affine for the gradients.
print_diff_summary(net.s_blk.cnv.cnv0.blk_bb.affine.dW(), net.s_blk.cnv.cnv0.cnv_torch.weight.grad)

In [None]:
# Weight and bias gradient differences for every convolution of the network.
# The paths follow the current structure: m_blk with two ConvBlocks, s_blk with
# the single ConvBlock `cnv`, and blk_bb.affine for the BinaryBrain gradients.
for label, layer in (
        ('m_blk.cnv0.cnv0', net.m_blk.cnv0.cnv0),
        ('m_blk.cnv0.cnv1', net.m_blk.cnv0.cnv1),
        ('m_blk.cnv1.cnv0', net.m_blk.cnv1.cnv0),
        ('m_blk.cnv1.cnv1', net.m_blk.cnv1.cnv1),
        ('s_blk.cnv.cnv0',  net.s_blk.cnv.cnv0),
        ('s_blk.cnv.cnv1',  net.s_blk.cnv.cnv1)):
    print_diff_summary(layer.blk_bb.affine.dW(), layer.cnv_torch.weight.grad, text=label+'.dW')
    print_diff_summary(layer.blk_bb.affine.db(), layer.cnv_torch.bias.grad,   text=label+'.db')